Repository: visinf/irr
Branch: master
Commit: dacd07b1dc96
Files: 89
Total size: 240.7 MB
Directory structure:
gitextract_pwdsdpvy/
├── .gitignore
├── LICENSE
├── README.md
├── __init__.py
├── augmentations.py
├── commandline.py
├── configuration.py
├── datasets/
│ ├── __init__.py
│ ├── common.py
│ ├── flyingThings3D.py
│ ├── flyingchairs.py
│ ├── flyingchairsOcc.py
│ ├── kitti_combined.py
│ ├── sintel.py
│ └── transforms.py
├── flyingchairsocc/
│ └── README.md
├── install.sh
├── logger.py
├── losses.py
├── main.py
├── models/
│ ├── IRR_FlowNet.py
│ ├── IRR_PWC.py
│ ├── __init__.py
│ ├── correlation_package/
│ │ ├── __init__.py
│ │ ├── correlation.py
│ │ ├── correlation_cuda.cc
│ │ ├── correlation_cuda_kernel.cu
│ │ ├── correlation_cuda_kernel.cuh
│ │ └── setup.py
│ ├── correlation_package_cu9/
│ │ ├── __init__.py
│ │ ├── correlation.py
│ │ ├── correlation_cuda.cc
│ │ ├── correlation_cuda_kernel.cu
│ │ ├── correlation_cuda_kernel.cuh
│ │ └── setup.py
│ ├── flownet1s.py
│ ├── flownet1s_irr.py
│ ├── flownet1s_irr_bi.py
│ ├── flownet1s_irr_occ.py
│ ├── flownet1s_irr_occ_bi.py
│ ├── flownet_modules.py
│ ├── irr_modules.py
│ ├── pwc_modules.py
│ ├── pwcnet.py
│ ├── pwcnet_bi.py
│ ├── pwcnet_irr.py
│ ├── pwcnet_irr_bi.py
│ ├── pwcnet_irr_occ.py
│ ├── pwcnet_irr_occ_bi.py
│ ├── pwcnet_occ.py
│ └── pwcnet_occ_bi.py
├── optim/
│ └── __init__.py
├── runtime.py
├── saved_check_point/
│ └── pwcnet/
│ ├── IRR-PWC_flyingchairsOcc/
│ │ ├── checkpoint_best.ckpt
│ │ └── checkpoint_latest.ckpt
│ ├── IRR-PWC_kitti/
│ │ ├── checkpoint_best.ckpt
│ │ └── checkpoint_latest.ckpt
│ ├── IRR-PWC_sintel/
│ │ ├── checkpoint_best.ckpt
│ │ └── checkpoint_latest.ckpt
│ ├── IRR-PWC_things3d/
│ │ ├── checkpoint_best.ckpt
│ │ └── checkpoint_latest.ckpt
│ ├── PWCNet/
│ │ └── checkpoint_best.ckpt
│ └── PWCNet-irr/
│ └── checkpoint_best.ckpt
├── scripts/
│ ├── IRR-FlowNet_flyingChairsOcc.sh
│ ├── IRR-PWC_flyingChairsOcc.sh
│ ├── IRR-PWC_kitti_train.sh
│ ├── IRR-PWC_kitti_train_full.sh
│ ├── IRR-PWC_sintel_train.sh
│ ├── IRR-PWC_sintel_train_full.sh
│ ├── IRR-PWC_things3d.sh
│ ├── flownet1s.sh
│ ├── flownet1s_irr1.sh
│ ├── flownet1s_irr2.sh
│ ├── pwcnet.sh
│ ├── pwcnet_irr.sh
│ └── validation/
│ ├── IRR-FlowNet_flyingChairs.sh
│ ├── IRR-PWC_flyingChairs.sh
│ ├── IRR-PWC_kitti.sh
│ ├── IRR-PWC_sintel.sh
│ ├── IRR-PWC_things3d.sh
│ ├── flownet1s.sh
│ ├── flownet1s_irr1.sh
│ ├── flownet1s_irr2.sh
│ ├── pwcnet.sh
│ └── pwcnet_irr.sh
├── tools.py
└── utils/
├── __init__.py
├── flow.py
└── interpolation.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
*.pyc
*.so
*.o
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
# Iterative Residual Refinement
for Joint Optical Flow and Occlusion Estimation
This repository is the PyTorch implementation of the paper:
**Iterative Residual Refinement for Joint Optical Flow and Occlusion Estimation (CVPR 2019)**
[Junhwa Hur](https://sites.google.com/site/hurjunhwa) and [Stefan Roth](https://www.visinf.tu-darmstadt.de/team_members/sroth/sroth.en.jsp)
Department of Computer Science, TU Darmstadt
[[Preprint]](https://arxiv.org/pdf/1904.05290.pdf) [[Proceeding]](http://openaccess.thecvf.com/content_CVPR_2019/papers/Hur_Iterative_Residual_Refinement_for_Joint_Optical_Flow_and_Occlusion_Estimation_CVPR_2019_paper.pdf) [[Supplemental]](http://openaccess.thecvf.com/content_CVPR_2019/supplemental/Hur_Iterative_Residual_Refinement_CVPR_2019_supplemental.pdf)
Please cite the paper below if you find our paper and source code useful.
@inproceedings{Hur:2019:IRR,
Author = {Junhwa Hur and Stefan Roth},
Booktitle = {CVPR},
Title = {Iterative Residual Refinement for Joint Optical Flow and Occlusion Estimation},
Year = {2019}
}
Contact: junhwa.hur[at]visinf.tu-darmstadt.de
## Getting started
This code was originally developed under Anaconda (Python 3.6), PyTorch 0.4.1 and CUDA 8.0 on Ubuntu 16.04.
1. Please install the following:
- Anaconda
- PyTorch (now compatible with __PyTorch 1.5.0__)
- tqdm (`conda install -c conda-forge tqdm==4.40.0`)
- (any missing packages that the code requires)
2. The datasets used for this project are the following:
- [FlyingChairsOcc dataset](https://github.com/visinf/irr/tree/master/flyingchairsocc)
- [FlyingThings3D subset](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html)
- [MPI Sintel Dataset](http://sintel.is.tue.mpg.de/downloads) + [revised occlusion GT](https://download.visinf.tu-darmstadt.de/data/flyingchairs_occ/occlusions_rev.zip)
- [KITTI Optical Flow 2015](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow) and [KITTI Optical Flow 2012](http://www.cvlibs.net/datasets/kitti/eval_stereo_flow.php?benchmark=flow)
## Training
The `scripts` folder contains training scripts of experiments demonstrated in the paper.
To train the model, you can simply run the script file, e.g., `./IRR-PWC_flyingChairsOcc.sh`.
In script files, please configure your own experiment directory (EXPERIMENTS_HOME) and dataset directory in your local system (e.g., SINTEL_HOME or KITTI_HOME).
## Pretrained Models
The `saved_check_point` contains the pretrained models of *i)* baseline, *ii)* baseline + irr, and *iii)* full models.
Additional pretrained models in the ablations study (Table 1 in the main paper) and their training scripts are available upon request.
## Inference
The scripts for testing the pre-trained models are located in `scripts/validation`.
## Acknowledgement
Portions of the source code (e.g., training pipeline, runtime, argument parser, and logger) are from [Jochen Gast](https://scholar.google.com/citations?user=tmRcFacAAAAJ&hl=en)
================================================
FILE: __init__.py
================================================
================================================
FILE: augmentations.py
================================================
## Portions of Code from, copyright 2018 Jochen Gast
from __future__ import absolute_import, division, print_function
import torch
import torch.nn as nn
from utils.interpolation import Interp2, Interp2MaskBinary
from utils.interpolation import Meshgrid
import numpy as np
def denormalize_coords(xx, yy, width, height):
    """Map normalized coordinates in [-1, 1] back to pixel coordinates in [0, width-1] / [0, height-1]."""
    half_w = 0.5 * (width - 1.0)
    half_h = 0.5 * (height - 1.0)
    return half_w * (xx.float() + 1.0), half_h * (yy.float() + 1.0)
def normalize_coords(xx, yy, width, height):
    """Map pixel coordinates in [0, width-1] / [0, height-1] to normalized coordinates in [-1, 1]."""
    scale_x = 2.0 / (width - 1.0)
    scale_y = 2.0 / (height - 1.0)
    return scale_x * xx.float() - 1.0, scale_y * yy.float() - 1.0
def apply_transform_to_params(theta0, theta_transform):
    """Compose two batches of 2x3 affine parameters: apply `theta_transform` after `theta0`.

    Each row is laid out as [a1, a2, a3, a4, a5, a6], i.e. x' = a1*x + a2*y + a3,
    y' = a4*x + a5*y + a6. Returns the composed parameters, shape (batch, 6).
    """
    a1, a2, a3, a4, a5, a6 = (theta0[:, i] for i in range(6))
    b1, b2, b3, b4, b5, b6 = (theta_transform[:, i] for i in range(6))
    # product of the 2x2 linear parts, plus the transformed translation
    c1 = a1 * b1 + a4 * b2
    c2 = a2 * b1 + a5 * b2
    c3 = b3 + a3 * b1 + a6 * b2
    c4 = a1 * b4 + a4 * b5
    c5 = a2 * b4 + a5 * b5
    c6 = b6 + a3 * b4 + a6 * b5
    return torch.stack([c1, c2, c3, c4, c5, c6], dim=1)
class _IdentityParams(nn.Module):
def __init__(self):
super(_IdentityParams, self).__init__()
self._batch_size = 0
self.register_buffer("_o", torch.FloatTensor())
self.register_buffer("_i", torch.FloatTensor())
def _update(self, batch_size):
torch.zeros([batch_size, 1], out=self._o)
torch.ones([batch_size, 1], out=self._i)
return torch.cat([self._i, self._o, self._o, self._o, self._i, self._o], dim=1)
def forward(self, batch_size):
if self._batch_size != batch_size:
self._identity_params = self._update(batch_size)
self._batch_size = batch_size
return self._identity_params
class RandomMirror(nn.Module):
    """Randomly mirrors two batches of affine parameters, in place.

    With probability p the sign of the x-coefficients (a1..a3) is flipped for a
    batch element; if `vertical` is set, the y-coefficients (a4..a6) are flipped
    independently with the same probability. Both parameter sets receive the
    same mirroring so the image pair stays consistent.
    """

    def __init__(self, vertical=True, p=0.5):
        super(RandomMirror, self).__init__()
        self._batch_size = 0
        self._p = p
        self._vertical = vertical
        self.register_buffer("_mirror_probs", torch.FloatTensor())

    def update_probs(self, batch_size):
        # resize the cached probability buffer and fill it with p
        torch.ones([batch_size, 1], out=self._mirror_probs)
        self._mirror_probs *= self._p

    def forward(self, theta1, theta2):
        batch_size = theta1.size(0)
        if self._batch_size != batch_size:
            self.update_probs(batch_size)
            self._batch_size = batch_size
        # random +/-1 sign for the x-responsible coefficients a1..a3
        sign = torch.sign(2.0 * torch.bernoulli(self._mirror_probs) - 1.0)
        ones = torch.ones_like(sign)
        mirror_x = torch.cat([sign, sign, sign, ones, ones, ones], dim=1)
        theta1 *= mirror_x
        theta2 *= mirror_x
        if self._vertical:
            # independent random sign for the y-responsible coefficients a4..a6
            sign = torch.sign(2.0 * torch.bernoulli(self._mirror_probs) - 1.0)
            mirror_y = torch.cat([ones, ones, ones, sign, sign, sign], dim=1)
            theta1 *= mirror_y
            theta2 *= mirror_y
        return theta1, theta2
class RandomCrop(nn.Module):
    """Crops the given image batch at a random location to a region of the
    given size. `crop` can be a tuple (target_height, target_width) or an
    integer, in which case the target will be of a square shape (crop, crop).

    Fixes over the previous implementation:
      - it referenced the nonexistent attribute ``self._size`` (AttributeError);
      - it sliced with zero-element tensor buffers instead of Python ints;
      - ``forward`` never returned the cropped tensors;
      - an int crop was promised by the docstring but not supported.
    """

    def __init__(self, crop):
        super(RandomCrop, self).__init__()
        if isinstance(crop, int):
            # square crop, as the docstring promises
            crop = (crop, crop)
        self._crop_size = crop
        # single-element buffers holding the sampled crop origin
        self.register_buffer("_x", torch.LongTensor(1))
        self.register_buffer("_y", torch.LongTensor(1))

    def forward(self, im1, im2, flo):
        """Crop the same random window out of im1, im2 and flo (N,C,H,W tensors).

        Returns the cropped (im1, im2, flo); inputs are returned unchanged when
        the configured crop size is degenerate (< 1 in either dimension).
        """
        _, _, height, width = im1.size()
        crop_height, crop_width = self._crop_size
        # check whether there is anything to do
        if crop_height < 1 or crop_width < 1:
            return im1, im2, flo
        # sample the top-left corner; the +1 makes a full-size crop reachable,
        # matching RandomAffineFlowOcc.random_crop
        self._x.random_(0, width - crop_width + 1)
        self._y.random_(0, height - crop_height + 1)
        str_x = int(self._x)
        str_y = int(self._y)
        end_x = str_x + crop_width
        end_y = str_y + crop_height
        im1 = im1[:, :, str_y:end_y, str_x:end_x]
        im2 = im2[:, :, str_y:end_y, str_x:end_x]
        flo = flo[:, :, str_y:end_y, str_x:end_x]
        return im1, im2, flo
class RandomAffineFlow(nn.Module):
    """Random affine augmentation for (image1, image2, flow) training examples.

    A global affine transform (theta1) is sampled for the first image and a
    slightly perturbed "relative" transform (theta2) for the second one; both
    images and the flow field are warped accordingly, random mirroring is
    applied, and optional Gaussian noise is added. Results are written back
    into the example dictionary.
    """

    def __init__(self, args, addnoise=True):
        super(RandomAffineFlow, self).__init__()
        self._args = args
        self._interp2 = Interp2(clamp=False)
        self._flow_interp2 = Interp2(clamp=False)
        self._meshgrid = Meshgrid()
        self._identity = _IdentityParams()
        self._random_mirror = RandomMirror()
        self._addnoise = addnoise
        self.register_buffer("_noise1", torch.FloatTensor())
        self.register_buffer("_noise2", torch.FloatTensor())
        # image corners in normalized [-1, 1] coordinates; used by find_invalid
        self.register_buffer("_xbounds", torch.FloatTensor([-1, -1, 1, 1]))
        self.register_buffer("_ybounds", torch.FloatTensor([-1, 1, -1, 1]))

    def inverse_transform_coords(self, width, height, thetas, offset_x=None, offset_y=None):
        """Apply the affine parameters `thetas` to the (optionally offset) pixel grid.

        Returns transformed pixel coordinates (xq, yq) of shape (batch, height, width).
        """
        xx, yy = self._meshgrid(width=width, height=height, device=thetas.device, dtype=thetas.dtype)
        xx = torch.unsqueeze(xx, dim=0).float()
        yy = torch.unsqueeze(yy, dim=0).float()
        if offset_x is not None:
            xx = xx + offset_x
        if offset_y is not None:
            yy = yy + offset_y
        # one scalar affine coefficient per batch element, broadcast over the grid
        a1 = thetas[:, 0].contiguous().view(-1, 1, 1)
        a2 = thetas[:, 1].contiguous().view(-1, 1, 1)
        a3 = thetas[:, 2].contiguous().view(-1, 1, 1)
        a4 = thetas[:, 3].contiguous().view(-1, 1, 1)
        a5 = thetas[:, 4].contiguous().view(-1, 1, 1)
        a6 = thetas[:, 5].contiguous().view(-1, 1, 1)
        # the affine transform operates in normalized [-1, 1] coordinates
        xx, yy = normalize_coords(xx, yy, width=width, height=height)
        xq = a1 * xx + a2 * yy + a3
        yq = a4 * xx + a5 * yy + a6
        xq, yq = denormalize_coords(xq, yq, width=width, height=height)
        return xq, yq

    def transform_coords(self, width, height, thetas):
        """Apply the inverse of `thetas` to the pixel grid (sampling coords for backward warping)."""
        xx1, yy1 = self._meshgrid(width=width, height=height, device=thetas.device, dtype=thetas.dtype)
        xx, yy = normalize_coords(xx1, yy1, width=width, height=height)

        def _unsqueeze12(u):
            return torch.unsqueeze(torch.unsqueeze(u, dim=1), dim=1)

        a1 = _unsqueeze12(thetas[:, 0])
        a2 = _unsqueeze12(thetas[:, 1])
        a3 = _unsqueeze12(thetas[:, 2])
        a4 = _unsqueeze12(thetas[:, 3])
        a5 = _unsqueeze12(thetas[:, 4])
        a6 = _unsqueeze12(thetas[:, 5])
        # invert the 2x2 linear part via its determinant z
        z = a1 * a5 - a2 * a4
        b1 = a5 / z
        b2 = - a2 / z
        b4 = - a4 / z
        b5 = a1 / z
        # undo the translation, then apply the inverted linear part
        xhat = xx - a3
        yhat = yy - a6
        xq = b1 * xhat + b2 * yhat
        yq = b4 * xhat + b5 * yhat
        xq, yq = denormalize_coords(xq, yq, width=width, height=height)
        return xq, yq

    def find_invalid(self, width, height, thetas):
        """Return a (batch, 1) byte mask marking thetas that map any image corner out of bounds."""
        x = self._xbounds
        y = self._ybounds
        a1 = torch.unsqueeze(thetas[:, 0], dim=1)
        a2 = torch.unsqueeze(thetas[:, 1], dim=1)
        a3 = torch.unsqueeze(thetas[:, 2], dim=1)
        a4 = torch.unsqueeze(thetas[:, 3], dim=1)
        a5 = torch.unsqueeze(thetas[:, 4], dim=1)
        a6 = torch.unsqueeze(thetas[:, 5], dim=1)
        # invert the 2x2 linear part via its determinant z
        z = a1 * a5 - a2 * a4
        b1 = a5 / z
        b2 = - a2 / z
        b4 = - a4 / z
        b5 = a1 / z
        # map the four corner points through the inverse transform
        xhat = x - a3
        yhat = y - a6
        xq = b1 * xhat + b2 * yhat
        yq = b4 * xhat + b5 * yhat
        xq, yq = denormalize_coords(xq, yq, width=width, height=height)
        # invalid if any of the four corners lands outside the image
        invalid = (
            (xq < 0) | (yq < 0) | (xq >= width) | (yq >= height)
        ).sum(dim=1, keepdim=True) > 0
        return invalid

    def apply_random_transforms_to_params(self,
                                          theta0,
                                          max_translate,
                                          min_zoom, max_zoom,
                                          min_squeeze, max_squeeze,
                                          min_rotate, max_rotate,
                                          validate_size=None):
        """Compose `theta0` with random zoom/squeeze/translation/rotation parameters.

        Keeps resampling per-batch-element until every resulting transform keeps
        the image within the bounds given by validate_size = [height, width].
        """
        # halve: translation is applied in normalized [-1, 1] coordinates where the image spans 2 units
        max_translate *= 0.5
        batch_size = theta0.size(0)
        height, width = validate_size
        # collect valid params here
        thetas = torch.zeros_like(theta0)
        zoom = theta0.new(batch_size, 1).zero_()
        squeeze = torch.zeros_like(zoom)
        tx = torch.zeros_like(zoom)
        ty = torch.zeros_like(zoom)
        phi = torch.zeros_like(zoom)
        invalid = torch.ones_like(zoom).byte()
        while invalid.sum() > 0:
            # random sampling
            zoom.uniform_(min_zoom, max_zoom)
            squeeze.uniform_(min_squeeze, max_squeeze)
            tx.uniform_(-max_translate, max_translate)
            ty.uniform_(-max_translate, max_translate)
            phi.uniform_(min_rotate, max_rotate)
            # construct affine parameters from zoom*squeeze scales and rotation phi
            sx = zoom * squeeze
            sy = zoom / squeeze
            sin_phi = torch.sin(phi)
            cos_phi = torch.cos(phi)
            b1 = cos_phi * sx
            b2 = sin_phi * sy
            b3 = tx
            b4 = - sin_phi * sx
            b5 = cos_phi * sy
            b6 = ty
            theta_transform = torch.cat([b1, b2, b3, b4, b5, b6], dim=1)
            theta_try = apply_transform_to_params(theta0, theta_transform)
            # keep previously valid thetas, replace only the still-invalid ones
            thetas = invalid.float() * theta_try + (1 - invalid.float()) * thetas
            # compute new invalid ones
            invalid = self.find_invalid(width=width, height=height, thetas=thetas)
        # here we should have good thetas within borders
        return thetas

    def transform_image(self, images, thetas):
        """Warp a batch of images with the affine parameters `thetas`."""
        batch_size, channels, height, width = images.size()
        xq, yq = self.transform_coords(width=width, height=height, thetas=thetas)
        transformed = self._interp2(images, xq, yq)
        return transformed

    def transform_flow(self, flow, theta1, theta2):
        """Transform a flow field that maps the theta1-image onto the theta2-image."""
        batch_size, channels, height, width = flow.size()
        u = flow[:, 0, :, :]
        v = flow[:, 1, :, :]
        # positions of the grid under theta1, and of the flow endpoints under theta2
        x0, y0 = self.inverse_transform_coords(
            width=width, height=height, thetas=theta1)
        x1, y1 = self.inverse_transform_coords(
            width=width, height=height, thetas=theta2, offset_x=u, offset_y=v)
        # their difference is the transformed flow
        u = x1 - x0
        v = y1 - y0
        new_flow = torch.stack([u, v], dim=1)
        # resample the new flow into the geometry of the theta1-warped image
        xq, yq = self.transform_coords(width=width, height=height, thetas=theta1)
        # interp2
        transformed = self._flow_interp2(new_flow, xq, yq)
        return transformed

    def forward(self, example_dict):
        """Augment the example: warps input1/input2/target1, adds optional noise, updates the dict."""
        im1 = example_dict["input1"]
        im2 = example_dict["input2"]
        flo = example_dict["target1"]
        batch_size = im1.size(0)
        height = im1.size(2)
        width = im1.size(3)
        # identity = no transform
        theta0 = self._identity(batch_size)
        # global transform for the first image
        theta1 = self.apply_random_transforms_to_params(
            theta0,
            max_translate=0.2,
            min_zoom=1.0, max_zoom=1.5,
            min_squeeze=0.86, max_squeeze=1.16,
            min_rotate=-0.2, max_rotate=0.2,
            validate_size=[height, width])
        # relative transform: a small perturbation of theta1 for the second image
        theta2 = self.apply_random_transforms_to_params(
            theta1,
            max_translate=0.015,
            min_zoom=0.985, max_zoom=1.015,
            min_squeeze=1.0, max_squeeze=1.0,
            min_rotate=-0.015, max_rotate=0.015,
            validate_size=[height, width])
        # random flip images
        theta1, theta2 = self._random_mirror(theta1, theta2)
        im1 = self.transform_image(im1, theta1)
        im2 = self.transform_image(im2, theta2)
        flo = self.transform_flow(flo, theta1, theta2)
        if self._addnoise:
            # additive Gaussian noise with one random stddev per call
            stddev = np.random.uniform(0.0, 0.04)
            self._noise1.resize_as_(im1)
            self._noise2.resize_as_(im2)
            self._noise1.normal_(std=stddev)
            self._noise2.normal_(std=stddev)
            im1 += self._noise1
            im2 += self._noise2
            im1.clamp_(0.0, 1.0)
            im2.clamp_(0.0, 1.0)
        # construct updated dictionaries
        example_dict["input1"] = im1
        example_dict["input2"] = im2
        example_dict["target1"] = flo
        return example_dict
class RandomAffineFlowOcc(nn.Module):
    """Random affine augmentation for examples with bidirectional flow and occlusion maps.

    Same scheme as RandomAffineFlow, extended to a backward flow (target2) and
    two occlusion maps (target_occ1/2), with optional random cropping.
    """

    def __init__(self, args, addnoise=True, crop=None):
        super(RandomAffineFlowOcc, self).__init__()
        self._args = args
        self._interp2 = Interp2(clamp=False)
        self._flow_interp2 = Interp2(clamp=False)
        self._meshgrid = Meshgrid()
        self._identity = _IdentityParams()
        self._random_mirror = RandomMirror()
        self._addnoise = addnoise
        # crop: None, or (crop_height, crop_width) consumed by random_crop
        self._crop = crop
        self.register_buffer("_noise1", torch.FloatTensor())
        self.register_buffer("_noise2", torch.FloatTensor())
        # image corners in normalized [-1, 1] coordinates; used by find_invalid
        self.register_buffer("_xbounds", torch.FloatTensor([-1, -1, 1, 1]))
        self.register_buffer("_ybounds", torch.FloatTensor([-1, 1, -1, 1]))
        # single-element buffers holding the sampled crop origin
        self.register_buffer("_x", torch.IntTensor(1))
        self.register_buffer("_y", torch.IntTensor(1))
    def inverse_transform_coords(self, width, height, thetas, offset_x=None, offset_y=None):
        """Apply the affine parameters `thetas` to the (optionally offset) pixel grid.

        Returns transformed pixel coordinates (xq, yq) of shape (batch, height, width).
        """
        xx, yy = self._meshgrid(width=width, height=height, device=thetas.device, dtype=thetas.dtype)
        xx = torch.unsqueeze(xx, dim=0).float()
        yy = torch.unsqueeze(yy, dim=0).float()
        if offset_x is not None:
            xx = xx + offset_x
        if offset_y is not None:
            yy = yy + offset_y
        # one scalar affine coefficient per batch element, broadcast over the grid
        a1 = thetas[:, 0].contiguous().view(-1, 1, 1)
        a2 = thetas[:, 1].contiguous().view(-1, 1, 1)
        a3 = thetas[:, 2].contiguous().view(-1, 1, 1)
        a4 = thetas[:, 3].contiguous().view(-1, 1, 1)
        a5 = thetas[:, 4].contiguous().view(-1, 1, 1)
        a6 = thetas[:, 5].contiguous().view(-1, 1, 1)
        # the affine transform operates in normalized [-1, 1] coordinates
        xx, yy = normalize_coords(xx, yy, width=width, height=height)
        xq = a1 * xx + a2 * yy + a3
        yq = a4 * xx + a5 * yy + a6
        xq, yq = denormalize_coords(xq, yq, width=width, height=height)
        return xq, yq
    def transform_coords(self, width, height, thetas):
        """Apply the inverse of `thetas` to the pixel grid (sampling coords for backward warping)."""
        xx1, yy1 = self._meshgrid(width=width, height=height, device=thetas.device, dtype=thetas.dtype)
        xx, yy = normalize_coords(xx1, yy1, width=width, height=height)

        def _unsqueeze12(u):
            return torch.unsqueeze(torch.unsqueeze(u, dim=1), dim=1)

        a1 = _unsqueeze12(thetas[:, 0])
        a2 = _unsqueeze12(thetas[:, 1])
        a3 = _unsqueeze12(thetas[:, 2])
        a4 = _unsqueeze12(thetas[:, 3])
        a5 = _unsqueeze12(thetas[:, 4])
        a6 = _unsqueeze12(thetas[:, 5])
        # invert the 2x2 linear part via its determinant z
        z = a1 * a5 - a2 * a4
        b1 = a5 / z
        b2 = - a2 / z
        b4 = - a4 / z
        b5 = a1 / z
        # undo the translation, then apply the inverted linear part
        xhat = xx - a3
        yhat = yy - a6
        xq = b1 * xhat + b2 * yhat
        yq = b4 * xhat + b5 * yhat
        xq, yq = denormalize_coords(xq, yq, width=width, height=height)
        return xq, yq
    def find_invalid(self, width, height, thetas):
        """Return a (batch, 1) byte mask marking thetas that map any image corner out of bounds."""
        x = self._xbounds
        y = self._ybounds
        a1 = torch.unsqueeze(thetas[:, 0], dim=1)
        a2 = torch.unsqueeze(thetas[:, 1], dim=1)
        a3 = torch.unsqueeze(thetas[:, 2], dim=1)
        a4 = torch.unsqueeze(thetas[:, 3], dim=1)
        a5 = torch.unsqueeze(thetas[:, 4], dim=1)
        a6 = torch.unsqueeze(thetas[:, 5], dim=1)
        # invert the 2x2 linear part via its determinant z
        z = a1 * a5 - a2 * a4
        b1 = a5 / z
        b2 = - a2 / z
        b4 = - a4 / z
        b5 = a1 / z
        # map the four corner points through the inverse transform
        xhat = x - a3
        yhat = y - a6
        xq = b1 * xhat + b2 * yhat
        yq = b4 * xhat + b5 * yhat
        xq, yq = denormalize_coords(xq, yq, width=width, height=height)
        # invalid if any of the four corners lands outside the image
        invalid = (
            (xq < 0) | (yq < 0) | (xq >= width) | (yq >= height)
        ).sum(dim=1, keepdim=True) > 0
        return invalid
    def apply_random_transforms_to_params(self,
                                          theta0,
                                          max_translate,
                                          min_zoom, max_zoom,
                                          min_squeeze, max_squeeze,
                                          min_rotate, max_rotate,
                                          validate_size=None):
        """Compose `theta0` with random zoom/squeeze/translation/rotation parameters.

        Keeps resampling per-batch-element until every resulting transform keeps
        the image within the bounds given by validate_size = [height, width].
        """
        # halve: translation is applied in normalized [-1, 1] coordinates where the image spans 2 units
        max_translate *= 0.5
        batch_size = theta0.size(0)
        height, width = validate_size
        # collect valid params here
        thetas = torch.zeros_like(theta0)
        zoom = theta0.new(batch_size, 1).zero_()
        squeeze = torch.zeros_like(zoom)
        tx = torch.zeros_like(zoom)
        ty = torch.zeros_like(zoom)
        phi = torch.zeros_like(zoom)
        invalid = torch.ones_like(zoom).byte()
        while invalid.sum() > 0:
            # random sampling
            zoom.uniform_(min_zoom, max_zoom)
            squeeze.uniform_(min_squeeze, max_squeeze)
            tx.uniform_(-max_translate, max_translate)
            ty.uniform_(-max_translate, max_translate)
            phi.uniform_(min_rotate, max_rotate)
            # construct affine parameters from zoom*squeeze scales and rotation phi
            sx = zoom * squeeze
            sy = zoom / squeeze
            sin_phi = torch.sin(phi)
            cos_phi = torch.cos(phi)
            b1 = cos_phi * sx
            b2 = sin_phi * sy
            b3 = tx
            b4 = - sin_phi * sx
            b5 = cos_phi * sy
            b6 = ty
            theta_transform = torch.cat([b1, b2, b3, b4, b5, b6], dim=1)
            theta_try = apply_transform_to_params(theta0, theta_transform)
            # keep previously valid thetas, replace only the still-invalid ones
            thetas = invalid.float() * theta_try + (1. - invalid.float()) * thetas
            # compute new invalid ones
            invalid = self.find_invalid(width=width, height=height, thetas=thetas)
        # here we should have good thetas within borders
        return thetas
def transform_image(self, images, thetas):
batch_size, channels, height, width = images.size()
xq, yq = self.transform_coords(width=width, height=height, thetas=thetas)
transformed = self._interp2(images, xq, yq)
return transformed
    def transform_flow(self, flow, theta1, theta2):
        """Transform a flow field that maps the theta1-image onto the theta2-image."""
        batch_size, channels, height, width = flow.size()
        u = flow[:, 0, :, :]
        v = flow[:, 1, :, :]
        # positions of the grid under theta1, and of the flow endpoints under theta2
        x0, y0 = self.inverse_transform_coords(
            width=width, height=height, thetas=theta1)
        x1, y1 = self.inverse_transform_coords(
            width=width, height=height, thetas=theta2, offset_x=u, offset_y=v)
        # their difference is the transformed flow
        u = x1 - x0
        v = y1 - y0
        new_flow = torch.stack([u, v], dim=1)
        # resample the new flow into the geometry of the theta1-warped image
        xq, yq = self.transform_coords(width=width, height=height, thetas=theta1)
        # interp2
        transformed = self._flow_interp2(new_flow, xq, yq)
        return transformed
    def check_out_of_bound(self, flow, occ, batch_size):
        """Mark pixels whose flow endpoint leaves the image as occluded.

        Returns `occ` with out-of-bound locations set to 1, clamped to [0, 1].
        """
        _, _, height, width = flow.size()
        u = flow[:, 0, :, :]
        v = flow[:, 1, :, :]
        xx, yy = self._meshgrid(width=width, height=height, device=flow.device, dtype=flow.dtype)
        xx = torch.unsqueeze(xx, dim=0).float()
        yy = torch.unsqueeze(yy, dim=0).float()
        # end point of each pixel's flow vector
        xx = xx.expand(batch_size, -1, -1) + u
        yy = yy.expand(batch_size, -1, -1) + v
        out_of_bound = ((xx < 0) | (yy < 0) | (xx >= width) | (yy >= height)).float().unsqueeze(1)
        occ = torch.clamp(out_of_bound + occ, 0, 1)
        return occ
def random_crop(self, im1, im2, flo_f, flo_b, occ1, occ2):
    """Crop all six tensors with one shared random window of size self._crop."""
    _, _, height, width = im1.size()
    crop_height, crop_width = self._crop
    # sample the crop origin via registered buffers (keeps RNG on module device)
    self._x.random_(0, width - crop_width + 1)
    self._y.random_(0, height - crop_height + 1)
    x0 = int(self._x)
    y0 = int(self._y)
    x1 = x0 + crop_width
    y1 = y0 + crop_height
    cropped = [t[:, :, y0:y1, x0:x1] for t in (im1, im2, flo_f, flo_b, occ1, occ2)]
    return tuple(cropped)
def forward(self, example_dict):
    """Augment a bidirectional flow example in place.

    Warps both images, forward/backward flow and both occlusion maps with
    consistent random affine transforms, optionally adds Gaussian noise and
    takes a random crop, then marks pixels whose flow leaves the frame as
    occluded.  Returns the (mutated) example_dict.
    """
    im1 = example_dict["input1"]
    im2 = example_dict["input2"]
    flo_f = example_dict["target1"]
    flo_b = example_dict["target2"]
    occ1 = example_dict["target_occ1"]
    occ2 = example_dict["target_occ2"]
    batch_size = im1.size(0)
    height = im1.size(2)
    width = im1.size(3)
    # identity = no transform
    theta0 = self._identity(batch_size)
    # global transform applied to the whole pair
    theta1 = self.apply_random_transforms_to_params(
        theta0,
        max_translate=0.2,
        min_zoom=1.0, max_zoom=1.5,
        min_squeeze=0.86, max_squeeze=1.16,
        min_rotate=-0.2, max_rotate=0.2,
        validate_size=[height, width])
    # small relative transform between the two frames
    theta2 = self.apply_random_transforms_to_params(
        theta1,
        max_translate=0.015,
        min_zoom=0.985, max_zoom=1.015,
        min_squeeze=1.0, max_squeeze=1.0,
        min_rotate=-0.015, max_rotate=0.015,
        validate_size=[height, width])
    # random flip images (both thetas flipped consistently)
    theta1, theta2 = self._random_mirror(theta1, theta2)
    im1 = self.transform_image(im1, theta1)
    im2 = self.transform_image(im2, theta2)
    # forward flow uses (theta1, theta2); backward flow swaps the roles
    flo_f = self.transform_flow(flo_f, theta1, theta2)
    flo_b = self.transform_flow(flo_b, theta2, theta1)
    occ1 = self.transform_image(occ1, theta1)
    occ2 = self.transform_image(occ2, theta2)
    if self._addnoise:
        # additive Gaussian noise with a random std in [0, 0.04]
        stddev = np.random.uniform(0.0, 0.04)
        self._noise1.resize_as_(im1)
        self._noise2.resize_as_(im2)
        self._noise1.normal_(std=stddev)
        self._noise2.normal_(std=stddev)
        im1 += self._noise1
        im2 += self._noise2
        im1.clamp_(0.0, 1.0)
        im2.clamp_(0.0, 1.0)
    if self._crop is not None:
        im1, im2, flo_f, flo_b, occ1, occ2 = self.random_crop(im1, im2, flo_f, flo_b, occ1, occ2)
    # pixels flowing out of the (possibly cropped) frame become occluded
    occ1 = self.check_out_of_bound(flo_f, occ1, batch_size)
    occ2 = self.check_out_of_bound(flo_b, occ2, batch_size)
    example_dict["input1"] = im1
    example_dict["input2"] = im2
    example_dict["target1"] = flo_f
    example_dict["target2"] = flo_b
    example_dict["target_occ1"] = occ1
    example_dict["target_occ2"] = occ2
    return example_dict
class RandomAffineFlowOccSintel(nn.Module):
    """Random affine augmentation for Sintel-style samples.

    Samples a global affine transform for the image pair plus a small
    relative transform between the two frames, then warps both images,
    the forward flow and the first-frame occlusion map consistently.
    Optionally adds Gaussian noise and takes a random crop.

    Expects ``example_dict`` with keys "input1", "input2", "target1"
    (forward flow) and "target_occ1"; the dict is updated in place.
    """

    def __init__(self, args, addnoise=True, crop=None):
        super(RandomAffineFlowOccSintel, self).__init__()
        self._args = args
        self._interp2 = Interp2(clamp=False)       # bilinear sampler for images / occ maps
        self._flow_interp2 = Interp2(clamp=False)  # bilinear sampler for flow fields
        self._meshgrid = Meshgrid()
        self._identity = _IdentityParams()         # produces identity affine parameters
        self._random_mirror = RandomMirror()
        self._addnoise = addnoise                  # add Gaussian image noise if True
        self._crop = crop                          # (crop_height, crop_width) or None
        # Buffers so temporaries and sampling state follow the module's device.
        self.register_buffer("_noise1", torch.FloatTensor())
        self.register_buffer("_noise2", torch.FloatTensor())
        # Normalized x/y coordinates of the four image corners, used to check
        # that a sampled transform keeps the whole image inside the frame.
        self.register_buffer("_xbounds", torch.FloatTensor([-1, -1, 1, 1]))
        self.register_buffer("_ybounds", torch.FloatTensor([-1, 1, -1, 1]))
        # Scratch buffers holding the random crop origin.
        self.register_buffer("_x", torch.IntTensor(1))
        self.register_buffer("_y", torch.IntTensor(1))

    def inverse_transform_coords(self, width, height, thetas, offset_x=None, offset_y=None):
        """Apply the 6-parameter affine `thetas` to every pixel coordinate
        (optionally offset by a flow field) and return pixel coords (xq, yq)."""
        xx, yy = self._meshgrid(width=width, height=height, device=thetas.device, dtype=thetas.dtype)
        xx = torch.unsqueeze(xx, dim=0).float()
        yy = torch.unsqueeze(yy, dim=0).float()
        if offset_x is not None:
            xx = xx + offset_x
        if offset_y is not None:
            yy = yy + offset_y
        # per-batch affine parameters, broadcast over the H x W grid
        a1 = thetas[:, 0].contiguous().view(-1, 1, 1)
        a2 = thetas[:, 1].contiguous().view(-1, 1, 1)
        a3 = thetas[:, 2].contiguous().view(-1, 1, 1)
        a4 = thetas[:, 3].contiguous().view(-1, 1, 1)
        a5 = thetas[:, 4].contiguous().view(-1, 1, 1)
        a6 = thetas[:, 5].contiguous().view(-1, 1, 1)
        # the affine acts on [-1, 1]-normalized coordinates
        xx, yy = normalize_coords(xx, yy, width=width, height=height)
        xq = a1 * xx + a2 * yy + a3
        yq = a4 * xx + a5 * yy + a6
        xq, yq = denormalize_coords(xq, yq, width=width, height=height)
        return xq, yq

    def transform_coords(self, width, height, thetas):
        """Return the sampling grid (xq, yq) for warping with `thetas`,
        i.e. the *inverse* of the affine applied to every pixel coordinate."""
        xx1, yy1 = self._meshgrid(width=width, height=height, device=thetas.device, dtype=thetas.dtype)
        xx, yy = normalize_coords(xx1, yy1, width=width, height=height)

        def _unsqueeze12(u):
            # (B,) -> (B, 1, 1) so parameters broadcast against the grid
            return torch.unsqueeze(torch.unsqueeze(u, dim=1), dim=1)

        a1 = _unsqueeze12(thetas[:, 0])
        a2 = _unsqueeze12(thetas[:, 1])
        a3 = _unsqueeze12(thetas[:, 2])
        a4 = _unsqueeze12(thetas[:, 3])
        a5 = _unsqueeze12(thetas[:, 4])
        a6 = _unsqueeze12(thetas[:, 5])
        # invert the 2x2 linear part [[a1, a2], [a4, a5]]
        z = a1 * a5 - a2 * a4
        b1 = a5 / z
        b2 = - a2 / z
        b4 = - a4 / z
        b5 = a1 / z
        # undo the translation, then apply the inverted linear part
        xhat = xx - a3
        yhat = yy - a6
        xq = b1 * xhat + b2 * yhat
        yq = b4 * xhat + b5 * yhat
        xq, yq = denormalize_coords(xq, yq, width=width, height=height)
        return xq, yq

    def find_invalid(self, width, height, thetas):
        """Return a (B, 1) mask of batch items whose inversely transformed
        image corners fall outside [0, width) x [0, height)."""
        x = self._xbounds
        y = self._ybounds
        a1 = torch.unsqueeze(thetas[:, 0], dim=1)
        a2 = torch.unsqueeze(thetas[:, 1], dim=1)
        a3 = torch.unsqueeze(thetas[:, 2], dim=1)
        a4 = torch.unsqueeze(thetas[:, 3], dim=1)
        a5 = torch.unsqueeze(thetas[:, 4], dim=1)
        a6 = torch.unsqueeze(thetas[:, 5], dim=1)
        # invert the 2x2 linear part, exactly as in transform_coords
        z = a1 * a5 - a2 * a4
        b1 = a5 / z
        b2 = - a2 / z
        b4 = - a4 / z
        b5 = a1 / z
        # map the four corners through the inverse transform
        xhat = x - a3
        yhat = y - a6
        xq = b1 * xhat + b2 * yhat
        yq = b4 * xhat + b5 * yhat
        xq, yq = denormalize_coords(xq, yq, width=width, height=height)
        # invalid if any corner lands outside the image
        invalid = (
            (xq < 0) | (yq < 0) | (xq >= width) | (yq >= height)
        ).sum(dim=1, keepdim=True) > 0
        return invalid

    def apply_random_transforms_to_params(self,
                                          theta0,
                                          max_translate,
                                          min_zoom, max_zoom,
                                          min_squeeze, max_squeeze,
                                          min_rotate, max_rotate,
                                          validate_size=None):
        """Compose `theta0` with random zoom/squeeze/translate/rotate params,
        rejection-sampling per batch item until the transformed image stays
        within `validate_size` = [height, width]."""
        # halved because translation is applied in [-1, 1]-normalized coords
        # (presumably so max_translate is a fraction of the image size)
        max_translate *= 0.5
        batch_size = theta0.size(0)
        height, width = validate_size
        # collect valid params here
        thetas = torch.zeros_like(theta0)
        zoom = theta0.new(batch_size, 1).zero_()
        squeeze = torch.zeros_like(zoom)
        tx = torch.zeros_like(zoom)
        ty = torch.zeros_like(zoom)
        phi = torch.zeros_like(zoom)
        # start with every item marked invalid so the loop runs at least once
        invalid = torch.ones_like(zoom).byte()
        while invalid.sum() > 0:
            # random sampling
            zoom.uniform_(min_zoom, max_zoom)
            squeeze.uniform_(min_squeeze, max_squeeze)
            tx.uniform_(-max_translate, max_translate)
            ty.uniform_(-max_translate, max_translate)
            phi.uniform_(min_rotate, max_rotate)
            # construct affine parameters: rotation x anisotropic scale + shift
            sx = zoom * squeeze
            sy = zoom / squeeze
            sin_phi = torch.sin(phi)
            cos_phi = torch.cos(phi)
            b1 = cos_phi * sx
            b2 = sin_phi * sy
            b3 = tx
            b4 = - sin_phi * sx
            b5 = cos_phi * sy
            b6 = ty
            theta_transform = torch.cat([b1, b2, b3, b4, b5, b6], dim=1)
            theta_try = apply_transform_to_params(theta0, theta_transform)
            # keep previously accepted params; replace only the invalid ones
            thetas = invalid.float() * theta_try + (1 - invalid.float()) * thetas
            # compute new invalid ones
            invalid = self.find_invalid(width=width, height=height, thetas=thetas)
        # here we should have good thetas within borders
        return thetas

    def transform_image(self, images, thetas):
        """Warp a batch of images (B, C, H, W) with the affine `thetas`."""
        batch_size, channels, height, width = images.size()
        xq, yq = self.transform_coords(width=width, height=height, thetas=thetas)
        transformed = self._interp2(images, xq, yq)
        return transformed

    def transform_flow(self, flow, theta1, theta2):
        """Transform a flow field consistently with frames warped by
        theta1/theta2: map both endpoints of each flow vector, rebuild the
        flow from the difference, and resample into the first warped frame."""
        batch_size, channels, height, width = flow.size()
        u = flow[:, 0, :, :]
        v = flow[:, 1, :, :]
        # endpoint positions under each frame's transform
        x0, y0 = self.inverse_transform_coords(
            width=width, height=height, thetas=theta1)
        x1, y1 = self.inverse_transform_coords(
            width=width, height=height, thetas=theta2, offset_x=u, offset_y=v)
        # subtract and create new flow
        u = x1 - x0
        v = y1 - y0
        new_flow = torch.stack([u, v], dim=1)
        # resample onto the first frame's warped grid
        xq, yq = self.transform_coords(width=width, height=height, thetas=theta1)
        transformed = self._flow_interp2(new_flow, xq, yq)
        return transformed

    def check_out_of_bound(self, flow, occ, batch_size):
        """Add pixels whose flow target leaves the frame to the occlusion map."""
        _, _, height, width = flow.size()
        u = flow[:, 0, :, :]
        v = flow[:, 1, :, :]
        xx, yy = self._meshgrid(width=width, height=height, device=flow.device, dtype=flow.dtype)
        xx = torch.unsqueeze(xx, dim=0)
        yy = torch.unsqueeze(yy, dim=0)
        # target position of every pixel after applying the flow
        xx = xx.expand(batch_size, -1, -1) + u
        yy = yy.expand(batch_size, -1, -1) + v
        out_of_bound = ((xx < 0) | (yy < 0) | (xx >= width) | (yy >= height)).float().unsqueeze(1)
        occ = torch.clamp(out_of_bound + occ, 0, 1)
        return occ

    def random_crop(self, im1, im2, flo_f, occ1):
        """Crop all four tensors with one shared random window of self._crop."""
        _, _, height, width = im1.size()
        crop_height, crop_width = self._crop
        # sample the crop origin via registered buffers
        self._x.random_(0, width - crop_width + 1)
        self._y.random_(0, height - crop_height + 1)
        str_x = int(self._x)
        str_y = int(self._y)
        end_x = int(self._x + crop_width)
        end_y = int(self._y + crop_height)
        im1 = im1[:, :, str_y:end_y, str_x:end_x]
        im2 = im2[:, :, str_y:end_y, str_x:end_x]
        flo_f = flo_f[:, :, str_y:end_y, str_x:end_x]
        occ1 = occ1[:, :, str_y:end_y, str_x:end_x]
        return im1, im2, flo_f, occ1

    def forward(self, example_dict):
        """Augment a Sintel example in place (images, forward flow and the
        first-frame occlusion map) and return the mutated example_dict."""
        im1 = example_dict["input1"]
        im2 = example_dict["input2"]
        flo_f = example_dict["target1"]
        occ1 = example_dict["target_occ1"]
        batch_size = im1.size(0)
        height = im1.size(2)
        width = im1.size(3)
        # identity = no transform
        theta0 = self._identity(batch_size)
        # global transform applied to the whole pair
        theta1 = self.apply_random_transforms_to_params(
            theta0,
            max_translate=0.2,
            min_zoom=1.0, max_zoom=1.5,
            min_squeeze=0.86, max_squeeze=1.16,
            min_rotate=-0.2, max_rotate=0.2,
            validate_size=[height, width])
        # small relative transform between the two frames
        theta2 = self.apply_random_transforms_to_params(
            theta1,
            max_translate=0.015,
            min_zoom=0.985, max_zoom=1.015,
            min_squeeze=1.0, max_squeeze=1.0,
            min_rotate=-0.015, max_rotate=0.015,
            validate_size=[height, width])
        # random flip images (both thetas flipped consistently)
        theta1, theta2 = self._random_mirror(theta1, theta2)
        im1 = self.transform_image(im1, theta1)
        im2 = self.transform_image(im2, theta2)
        flo_f = self.transform_flow(flo_f, theta1, theta2)
        occ1 = self.transform_image(occ1, theta1)
        if self._addnoise:
            # additive Gaussian noise with a random std in [0, 0.04]
            stddev = np.random.uniform(0.0, 0.04)
            self._noise1.resize_as_(im1)
            self._noise2.resize_as_(im2)
            self._noise1.normal_(std=stddev)
            self._noise2.normal_(std=stddev)
            im1 += self._noise1
            im2 += self._noise2
            im1.clamp_(0.0, 1.0)
            im2.clamp_(0.0, 1.0)
        if self._crop is not None:
            im1, im2, flo_f, occ1 = self.random_crop(im1, im2, flo_f, occ1)
        # pixels flowing out of the (possibly cropped) frame become occluded
        occ1 = self.check_out_of_bound(flo_f, occ1, batch_size)
        example_dict["input1"] = im1
        example_dict["input2"] = im2
        example_dict["target1"] = flo_f
        example_dict["target_occ1"] = occ1
        return example_dict
class RandomAffineFlowOccKITTI(nn.Module):
    """Random affine augmentation for KITTI-style samples with sparse flow.

    KITTI ground-truth flow is valid only at some pixels, so the flow is
    resampled with a mask-aware interpolator (Interp2MaskBinary) that also
    transforms the validity mask.  Transform magnitudes are much smaller
    than the Sintel variant's, and vertical mirroring is disabled.

    Expects ``example_dict`` with keys "input1", "input2", "target1" and
    "input_valid"; the dict is updated in place.
    """

    def __init__(self, args, addnoise=True, crop=None):
        super(RandomAffineFlowOccKITTI, self).__init__()
        self._args = args
        self._interp2 = Interp2(clamp=False)                # bilinear sampler for images
        self._flow_interp2 = Interp2MaskBinary(clamp=False) # mask-aware sampler for sparse flow
        self._meshgrid = Meshgrid()
        self._identity = _IdentityParams()                  # produces identity affine parameters
        self._random_mirror = RandomMirror(vertical=False)  # no vertical flips for KITTI
        self._addnoise = addnoise                           # add Gaussian image noise if True
        self._crop = crop                                   # (crop_height, crop_width) or None
        # Buffers so temporaries and sampling state follow the module's device.
        self.register_buffer("_noise1", torch.FloatTensor())
        self.register_buffer("_noise2", torch.FloatTensor())
        # Normalized x/y coordinates of the four image corners, used to check
        # that a sampled transform keeps the whole image inside the frame.
        self.register_buffer("_xbounds", torch.FloatTensor([-1, -1, 1, 1]))
        self.register_buffer("_ybounds", torch.FloatTensor([-1, 1, -1, 1]))
        # Scratch buffers holding the random crop origin.
        self.register_buffer("_x", torch.IntTensor(1))
        self.register_buffer("_y", torch.IntTensor(1))

    def inverse_transform_coords(self, width, height, thetas, offset_x=None, offset_y=None):
        """Apply the 6-parameter affine `thetas` to every pixel coordinate
        (optionally offset by a flow field) and return pixel coords (xq, yq)."""
        xx, yy = self._meshgrid(width=width, height=height, device=thetas.device, dtype=thetas.dtype)
        xx = torch.unsqueeze(xx, dim=0).float()
        yy = torch.unsqueeze(yy, dim=0).float()
        if offset_x is not None:
            xx = xx + offset_x
        if offset_y is not None:
            yy = yy + offset_y
        # per-batch affine parameters, broadcast over the H x W grid
        a1 = thetas[:, 0].contiguous().view(-1, 1, 1)
        a2 = thetas[:, 1].contiguous().view(-1, 1, 1)
        a3 = thetas[:, 2].contiguous().view(-1, 1, 1)
        a4 = thetas[:, 3].contiguous().view(-1, 1, 1)
        a5 = thetas[:, 4].contiguous().view(-1, 1, 1)
        a6 = thetas[:, 5].contiguous().view(-1, 1, 1)
        # the affine acts on [-1, 1]-normalized coordinates
        xx, yy = normalize_coords(xx, yy, width=width, height=height)
        xq = a1 * xx + a2 * yy + a3
        yq = a4 * xx + a5 * yy + a6
        xq, yq = denormalize_coords(xq, yq, width=width, height=height)
        return xq, yq

    def transform_coords(self, width, height, thetas):
        """Return the sampling grid (xq, yq) for warping with `thetas`,
        i.e. the *inverse* of the affine applied to every pixel coordinate."""
        xx1, yy1 = self._meshgrid(width=width, height=height, device=thetas.device, dtype=thetas.dtype)
        xx, yy = normalize_coords(xx1, yy1, width=width, height=height)

        def _unsqueeze12(u):
            # (B,) -> (B, 1, 1) so parameters broadcast against the grid
            return torch.unsqueeze(torch.unsqueeze(u, dim=1), dim=1)

        a1 = _unsqueeze12(thetas[:, 0])
        a2 = _unsqueeze12(thetas[:, 1])
        a3 = _unsqueeze12(thetas[:, 2])
        a4 = _unsqueeze12(thetas[:, 3])
        a5 = _unsqueeze12(thetas[:, 4])
        a6 = _unsqueeze12(thetas[:, 5])
        # invert the 2x2 linear part [[a1, a2], [a4, a5]]
        z = a1 * a5 - a2 * a4
        b1 = a5 / z
        b2 = - a2 / z
        b4 = - a4 / z
        b5 = a1 / z
        # undo the translation, then apply the inverted linear part
        xhat = xx - a3
        yhat = yy - a6
        xq = b1 * xhat + b2 * yhat
        yq = b4 * xhat + b5 * yhat
        xq, yq = denormalize_coords(xq, yq, width=width, height=height)
        return xq, yq

    def find_invalid(self, width, height, thetas):
        """Return a (B, 1) mask of batch items whose inversely transformed
        image corners fall outside [0, width) x [0, height)."""
        x = self._xbounds
        y = self._ybounds
        a1 = torch.unsqueeze(thetas[:, 0], dim=1)
        a2 = torch.unsqueeze(thetas[:, 1], dim=1)
        a3 = torch.unsqueeze(thetas[:, 2], dim=1)
        a4 = torch.unsqueeze(thetas[:, 3], dim=1)
        a5 = torch.unsqueeze(thetas[:, 4], dim=1)
        a6 = torch.unsqueeze(thetas[:, 5], dim=1)
        # invert the 2x2 linear part, exactly as in transform_coords
        z = a1 * a5 - a2 * a4
        b1 = a5 / z
        b2 = - a2 / z
        b4 = - a4 / z
        b5 = a1 / z
        # map the four corners through the inverse transform
        xhat = x - a3
        yhat = y - a6
        xq = b1 * xhat + b2 * yhat
        yq = b4 * xhat + b5 * yhat
        xq, yq = denormalize_coords(xq, yq, width=width, height=height)
        # invalid if any corner lands outside the image
        invalid = (
            (xq < 0) | (yq < 0) | (xq >= width) | (yq >= height)
        ).sum(dim=1, keepdim=True) > 0
        return invalid

    def apply_random_transforms_to_params(self,
                                          theta0,
                                          max_translate,
                                          min_zoom, max_zoom,
                                          min_squeeze, max_squeeze,
                                          min_rotate, max_rotate,
                                          validate_size=None):
        """Compose `theta0` with random zoom/squeeze/translate/rotate params,
        rejection-sampling per batch item until the transformed image stays
        within `validate_size` = [height, width]."""
        # halved because translation is applied in [-1, 1]-normalized coords
        # (presumably so max_translate is a fraction of the image size)
        max_translate *= 0.5
        batch_size = theta0.size(0)
        height, width = validate_size
        # collect valid params here
        thetas = torch.zeros_like(theta0)
        zoom = theta0.new(batch_size, 1).zero_()
        squeeze = torch.zeros_like(zoom)
        tx = torch.zeros_like(zoom)
        ty = torch.zeros_like(zoom)
        phi = torch.zeros_like(zoom)
        # start with every item marked invalid so the loop runs at least once
        invalid = torch.ones_like(zoom).byte()
        while invalid.sum() > 0:
            # random sampling
            zoom.uniform_(min_zoom, max_zoom)
            squeeze.uniform_(min_squeeze, max_squeeze)
            tx.uniform_(-max_translate, max_translate)
            ty.uniform_(-max_translate, max_translate)
            phi.uniform_(min_rotate, max_rotate)
            # construct affine parameters: rotation x anisotropic scale + shift
            sx = zoom * squeeze
            sy = zoom / squeeze
            sin_phi = torch.sin(phi)
            cos_phi = torch.cos(phi)
            b1 = cos_phi * sx
            b2 = sin_phi * sy
            b3 = tx
            b4 = - sin_phi * sx
            b5 = cos_phi * sy
            b6 = ty
            theta_transform = torch.cat([b1, b2, b3, b4, b5, b6], dim=1)
            theta_try = apply_transform_to_params(theta0, theta_transform)
            # keep previously accepted params; replace only the invalid ones
            thetas = invalid.float() * theta_try + (1 - invalid.float()) * thetas
            # compute new invalid ones
            invalid = self.find_invalid(width=width, height=height, thetas=thetas)
        # here we should have good thetas within borders
        return thetas

    def transform_image(self, images, thetas):
        """Warp a batch of images (B, C, H, W) with the affine `thetas`."""
        batch_size, channels, height, width = images.size()
        xq, yq = self.transform_coords(width=width, height=height, thetas=thetas)
        transformed = self._interp2(images, xq, yq)
        return transformed

    def transform_flow(self, flow, theta1, theta2, valid_mask):
        """Transform a sparse flow field and its validity mask consistently
        with frames warped by theta1/theta2.  Returns (flow, valid_mask)."""
        batch_size, channels, height, width = flow.size()
        u = flow[:, 0, :, :]
        v = flow[:, 1, :, :]
        # endpoint positions under each frame's transform
        x0, y0 = self.inverse_transform_coords(
            width=width, height=height, thetas=theta1)
        x1, y1 = self.inverse_transform_coords(
            width=width, height=height, thetas=theta2, offset_x=u, offset_y=v)
        # subtract and create new flow
        u = x1 - x0
        v = y1 - y0
        new_flow = torch.stack([u, v], dim=1)
        # resample onto the first frame's warped grid
        xq, yq = self.transform_coords(width=width, height=height, thetas=theta1)
        # mask-aware interpolation: the validity mask is warped alongside
        transformed, valid_mask = self._flow_interp2(new_flow, xq, yq, valid_mask)
        return transformed, valid_mask

    def check_out_of_bound(self, flow, occ, batch_size):
        """Add pixels whose flow target leaves the frame to the occlusion map.
        NOTE(review): not invoked by this class's forward()."""
        _, _, height, width = flow.size()
        u = flow[:, 0, :, :]
        v = flow[:, 1, :, :]
        xx, yy = self._meshgrid(width=width, height=height, device=flow.device, dtype=flow.dtype)
        xx = torch.unsqueeze(xx, dim=0).float()
        yy = torch.unsqueeze(yy, dim=0).float()
        # target position of every pixel after applying the flow
        xx = xx.expand(batch_size, -1, -1) + u
        yy = yy.expand(batch_size, -1, -1) + v
        out_of_bound = ((xx < 0) | (yy < 0) | (xx >= width) | (yy >= height)).float().unsqueeze(1)
        occ = torch.clamp(out_of_bound + occ, 0, 1)
        return occ

    def random_crop(self, im1, im2, flo_f, valid_mask):
        """Crop all four tensors with one shared random window of self._crop."""
        _, _, height, width = im1.size()
        crop_height, crop_width = self._crop
        # sample the crop origin via registered buffers
        self._x.random_(0, width - crop_width + 1)
        self._y.random_(0, height - crop_height + 1)
        str_x = int(self._x)
        str_y = int(self._y)
        end_x = int(self._x + crop_width)
        end_y = int(self._y + crop_height)
        im1 = im1[:, :, str_y:end_y, str_x:end_x]
        im2 = im2[:, :, str_y:end_y, str_x:end_x]
        flo_f = flo_f[:, :, str_y:end_y, str_x:end_x]
        valid_mask = valid_mask[:, :, str_y:end_y, str_x:end_x]
        return im1, im2, flo_f, valid_mask

    def forward(self, example_dict):
        """Augment a KITTI example in place (images, forward flow and its
        validity mask) and return the mutated example_dict."""
        im1 = example_dict["input1"]
        im2 = example_dict["input2"]
        flo_f = example_dict["target1"]
        valid_mask = example_dict["input_valid"]
        batch_size = im1.size(0)
        height = im1.size(2)
        width = im1.size(3)
        # identity = no transform
        theta0 = self._identity(batch_size)
        # global transform (much gentler ranges than the Sintel variant)
        theta1 = self.apply_random_transforms_to_params(
            theta0,
            max_translate=0.04,
            min_zoom=0.98, max_zoom=1.02,
            min_squeeze=1.0, max_squeeze=1.0,
            min_rotate=-0.01, max_rotate=0.01,
            validate_size=[height, width])
        # small relative transform between the two frames
        theta2 = self.apply_random_transforms_to_params(
            theta1,
            max_translate=0.005,
            min_zoom=0.99, max_zoom=1.01,
            min_squeeze=1.0, max_squeeze=1.0,
            min_rotate=-0.01, max_rotate=0.01,
            validate_size=[height, width])
        # random flip images (horizontal only for KITTI)
        theta1, theta2 = self._random_mirror(theta1, theta2)
        im1 = self.transform_image(im1, theta1)
        im2 = self.transform_image(im2, theta2)
        flo_f, valid_mask = self.transform_flow(flo_f, theta1, theta2, valid_mask)
        if self._addnoise:
            # additive Gaussian noise with a random std in [0, 0.04]
            stddev = np.random.uniform(0.0, 0.04)
            self._noise1.resize_as_(im1)
            self._noise2.resize_as_(im2)
            self._noise1.normal_(std=stddev)
            self._noise2.normal_(std=stddev)
            im1 += self._noise1
            im2 += self._noise2
            im1.clamp_(0.0, 1.0)
            im2.clamp_(0.0, 1.0)
        if self._crop is not None:
            im1, im2, flo_f, valid_mask = self.random_crop(im1, im2, flo_f, valid_mask)
        example_dict["input1"] = im1
        example_dict["input2"] = im2
        example_dict["target1"] = flo_f
        example_dict["input_valid"] = valid_mask
        return example_dict
================================================
FILE: commandline.py
================================================
## Portions of Code from, copyright 2018 Jochen Gast
from __future__ import absolute_import, division, print_function
import argparse
import colorama
import inspect
import os
import sys
import torch
import datasets
import losses
import models
import augmentations
import tools
import logger
import logging
import optim
def _get_type_from_arg(arg):
if isinstance(arg, bool):
return tools.str2bool
else:
return type(arg)
def _add_arguments_for_module(parser,
                              module,
                              name,
                              default_class,
                              add_class_argument=True,   # whether to add class choice as argument
                              include_classes="*",
                              exclude_classes=[],
                              exclude_params=["self", "args"],
                              param_defaults={},         # allows to overwrite any default param
                              forced_default_types={},   # allows to set types for known arguments
                              unknown_default_types={}): # allows to set types for unknown arguments
    """Add "--<name>" (a class choice from `module`) and "--<name>_<param>"
    arguments inferred from the chosen class constructor's signature.

    Per-parameter defaults are resolved in priority order:
    `param_defaults` override, then the constructor default (type possibly
    replaced via `forced_default_types`), then `unknown_default_types`;
    anything else raises ValueError.
    """
    # -------------------------------------------------------------------------
    # Determine possible choices from class names in module, possibly apply include/exclude filters
    # -------------------------------------------------------------------------
    module_dict = tools.module_classes_to_dict(
        module, include_classes=include_classes, exclude_classes=exclude_classes)

    # -------------------------------------------------------------------------
    # Parse known arguments to determine choice for argument name
    # -------------------------------------------------------------------------
    if add_class_argument:
        parser.add_argument(
            "--%s" % name, type=str, default=default_class, choices=module_dict.keys())
        known_args = parser.parse_known_args(sys.argv[1:])[0]
    else:
        # build a temporary parser, and do not add the class as argument
        tmp_parser = argparse.ArgumentParser()
        tmp_parser.add_argument(
            "--%s" % name, type=str, default=default_class, choices=module_dict.keys())
        known_args = tmp_parser.parse_known_args(sys.argv[1:])[0]

    class_name = vars(known_args)[name]

    # -------------------------------------------------------------------------
    # If class is None, there is no point in trying to parse further arguments
    # -------------------------------------------------------------------------
    if class_name is None:
        return

    # -------------------------------------------------------------------------
    # Get constructor of that argument choice
    # -------------------------------------------------------------------------
    class_constructor = module_dict[class_name]

    # -------------------------------------------------------------------------
    # Determine constructor argument names and defaults.
    # BUGFIX: inspect.getargspec was deprecated and removed in Python 3.11;
    # prefer getfullargspec and keep getargspec only as a legacy fallback.
    # The `or` avoids touching inspect.getargspec on interpreters where it
    # no longer exists.
    # -------------------------------------------------------------------------
    _getargspec = getattr(inspect, "getfullargspec", None) or inspect.getargspec
    try:
        argspec = _getargspec(class_constructor.__init__)
    except TypeError:
        # BUGFIX: the original printed `argspec` here, which is unbound when
        # introspection itself fails (e.g. a C-implemented __init__), masking
        # the intended error with a NameError.
        raise ValueError("unknown_default_types should be adjusted for module: '%s.py'" % name)
    argspec_defaults = argspec.defaults if argspec.defaults is not None else []
    full_args = argspec.args
    default_args_dict = dict(zip(argspec.args[-len(argspec_defaults):], argspec_defaults))

    # -------------------------------------------------------------------------
    # Add sub_arguments
    # -------------------------------------------------------------------------
    for argname in full_args:
        # ---------------------------------------------------------------------
        # Skip excluded parameters (e.g. self, args)
        # ---------------------------------------------------------------------
        if argname in exclude_params:
            continue
        # ---------------------------------------------------------------------
        # Sub argument name, e.g. --model_div_flow
        # ---------------------------------------------------------------------
        sub_arg_name = "%s_%s" % (name, argname)
        # ---------------------------------------------------------------------
        # If a default argument is given, take that one
        # ---------------------------------------------------------------------
        if argname in param_defaults.keys():
            parser.add_argument(
                "--%s" % sub_arg_name,
                type=_get_type_from_arg(param_defaults[argname]),
                default=param_defaults[argname])
        # ---------------------------------------------------------------------
        # If a default parameter can be inferred from the module, pick that one
        # ---------------------------------------------------------------------
        elif argname in default_args_dict.keys():
            # -----------------------------------------------------------------
            # Check for forced default types
            # -----------------------------------------------------------------
            if argname in forced_default_types.keys():
                argtype = forced_default_types[argname]
            else:
                argtype = _get_type_from_arg(default_args_dict[argname])
            parser.add_argument(
                "--%s" % sub_arg_name, type=argtype, default=default_args_dict[argname])
        # ---------------------------------------------------------------------
        # Take from the unknowns list
        # ---------------------------------------------------------------------
        elif argname in unknown_default_types.keys():
            parser.add_argument("--%s" % sub_arg_name, type=unknown_default_types[argname])
        else:
            raise ValueError(
                "Do not know how to handle argument '%s' for class '%s'" % (argname, name))
def _add_special_arguments(parser):
    """Add arguments whose presence depends on earlier class choices
    (training/validation loss, checkpoint) plus optimizer-group support."""
    # arguments chosen so far on the commandline
    known = vars(parser.parse_known_args(sys.argv[1:])[0])

    # training: which scalar of the loss dictionary to optimize
    if known["training_loss"] is not None:
        parser.add_argument("--training_key", type=str, default="total_loss")

    # validation: which scalar to monitor, and whether lower is better
    if known["validation_loss"] is not None:
        parser.add_argument("--validation_key", type=str, default="total_loss")
        parser.add_argument("--validation_key_minimize", type=tools.str2bool, default=True)

    # checkpoints: restore mode plus include/exclude parameter filters
    if known["checkpoint"] is not None:
        parser.add_argument(
            "--checkpoint_mode", type=str, default="resume_from_latest",
            choices=["resume_from_latest", "resume_from_best"])
        parser.add_argument(
            "--checkpoint_include_params", type=tools.str2list, default="[*]")
        parser.add_argument(
            "--checkpoint_exclude_params", type=tools.str2list, default="[]")

    # optimizer parameter groups; may be given multiple times
    parser.add_argument("--optimizer_group", action="append", type=tools.str2dict, default=None)
def _parse_arguments():
    """Build the full commandline parser (standard args plus arguments
    inferred from losses, models, augmentations, datasets, optimizers and
    lr schedulers) and return (parsed args, dict of default values)."""
    # -------------------------------------------------------------------------
    # Argument parser and shortcut function to add arguments
    # -------------------------------------------------------------------------
    parser = argparse.ArgumentParser()
    add = parser.add_argument
    # -------------------------------------------------------------------------
    # Standard arguments
    # -------------------------------------------------------------------------
    add("--batch_size", type=int, default=1)
    add("--batch_size_val", type=int, default=1)
    add("--checkpoint", type=tools.str2str_or_none, default=None)
    add("--cuda", type=tools.str2bool, default=True)
    add("--evaluation", type=tools.str2bool, default=False)
    add("--name", default="run", type=str)
    add("--num_workers", type=int, default=4)
    add("--save", "-s", default="/tmp/work", type=str)
    add("--seed", type=int, default=1)
    add("--start_epoch", type=int, default=1)
    add("--total_epochs", type=int, default=10)
    add("--save_result_path_name", default="", type=str)
    add("--save_result_img", type=tools.str2bool, default=False)
    add("--save_result_occ", type=tools.str2bool, default=False)
    add("--save_result_flo", type=tools.str2bool, default=False)
    add("--save_result_png", type=tools.str2bool, default=False)
    add("--save_result_bidirection", type=tools.str2bool, default=False)
    add("--num_iters", type=int, default=1)
    # -------------------------------------------------------------------------
    # Arguments inferred from losses
    # -------------------------------------------------------------------------
    _add_arguments_for_module(
        parser,
        losses,
        name="training_loss",
        default_class=None,
        exclude_classes=["_*", "Variable"],
        exclude_params=["self","args"])
    _add_arguments_for_module(
        parser,
        losses,
        name="validation_loss",
        default_class=None,
        exclude_classes=["_*", "Variable"],
        exclude_params=["self","args"])
    # -------------------------------------------------------------------------
    # Arguments inferred from models
    # -------------------------------------------------------------------------
    _add_arguments_for_module(
        parser,
        models,
        name="model",
        default_class="FlowNet1S",
        exclude_classes=["_*", "Variable"],
        exclude_params=["self","args"])
    # -------------------------------------------------------------------------
    # Arguments inferred from augmentations for training
    # -------------------------------------------------------------------------
    _add_arguments_for_module(
        parser,
        augmentations,
        name="training_augmentation",
        default_class=None,
        exclude_classes=["_*"],
        exclude_params=["self","args"],
        forced_default_types={"crop": tools.str2intlist})
    # -------------------------------------------------------------------------
    # Arguments inferred from augmentations for validation
    # -------------------------------------------------------------------------
    _add_arguments_for_module(
        parser,
        augmentations,
        name="validation_augmentation",
        default_class=None,
        exclude_classes=["_*"],
        exclude_params=["self","args"])
    # -------------------------------------------------------------------------
    # Arguments inferred from datasets for training
    # -------------------------------------------------------------------------
    _add_arguments_for_module(
        parser,
        datasets,
        name="training_dataset",
        default_class=None,
        exclude_params=["self", "args", "is_cropped"],
        exclude_classes=["_*"],
        unknown_default_types={"root": str})
    # -------------------------------------------------------------------------
    # Arguments inferred from datasets for validation
    # -------------------------------------------------------------------------
    _add_arguments_for_module(
        parser,
        datasets,
        name="validation_dataset",
        default_class=None,
        exclude_params=["self", "args", "is_cropped"],
        exclude_classes=["_*"],
        unknown_default_types={"root": str})
    # -------------------------------------------------------------------------
    # Arguments inferred from PyTorch optimizers
    # -------------------------------------------------------------------------
    _add_arguments_for_module(
        parser,
        optim,
        name="optimizer",
        default_class="Adam",
        exclude_classes=["_*","Optimizer", "constructor"],
        exclude_params=["self", "args", "params"],
        forced_default_types={"lr": float,
                              "momentum": float,
                              "dampening": float,
                              "weight_decay": float,
                              "nesterov": tools.str2bool})
    # -------------------------------------------------------------------------
    # Arguments inferred from PyTorch lr schedulers
    # -------------------------------------------------------------------------
    _add_arguments_for_module(
        parser,
        torch.optim.lr_scheduler,
        name="lr_scheduler",
        default_class=None,
        exclude_classes=["_*","Optimizer"],
        exclude_params=["self", "args", "optimizer"],
        unknown_default_types={"T_max": int,
                               "lr_lambda": str,
                               "step_size": int,
                               "milestones": tools.str2intlist,
                               "gamma": float})
    # -------------------------------------------------------------------------
    # Special arguments (conditional on loss/checkpoint choices)
    # -------------------------------------------------------------------------
    _add_special_arguments(parser)
    # -------------------------------------------------------------------------
    # Parse arguments
    # -------------------------------------------------------------------------
    args = parser.parse_args()
    # -------------------------------------------------------------------------
    # Parse default arguments from a dummy commandline not specifying any args;
    # '--dummy' is unknown and ignored by parse_known_args
    # -------------------------------------------------------------------------
    defaults = vars(parser.parse_known_args(['--dummy'])[0])
    # -------------------------------------------------------------------------
    # Consistency checks: only use CUDA when it is actually available
    # -------------------------------------------------------------------------
    args.cuda = args.cuda and torch.cuda.is_available()
    return args, defaults
def postprocess_args(args):
    """Resolve the class names chosen on the commandline into constructors
    and attach them to args as '<name>_class' attributes."""
    # the model is mandatory; everything else is optional
    args.model_class = tools.module_classes_to_dict(models)[args.model]

    # (argument attribute, module providing the class) pairs resolved
    # only when a class name was actually chosen
    optional_choices = [
        ("optimizer", optim),
        ("training_loss", losses),
        ("validation_loss", losses),
        ("lr_scheduler", torch.optim.lr_scheduler),
        ("training_dataset", datasets),
        ("validation_dataset", datasets),
        ("training_augmentation", augmentations),
        ("validation_augmentation", augmentations),
    ]
    for attr, module in optional_choices:
        chosen = getattr(args, attr)
        if chosen is not None:
            setattr(args, attr + "_class", tools.module_classes_to_dict(module)[chosen])

    return args
def setup_logging_and_parse_arguments(blocktitle):
    """Parse commandline arguments, configure the logbook, and log the args.

    Arguments whose value differs from the parser default are highlighted
    in cyan. Returns the postprocessed args namespace.
    """
    # Parse commandline and default arguments first.
    args, defaults = _parse_arguments()

    # Logging must be configured before anything else writes to the logbook.
    logger.configure_logging(os.path.join(args.save, 'logbook.txt'))

    # Persist the arguments as plain text next to the logbook.
    tools.write_dictionary_to_file(
        sorted(vars(args).items()),
        filename=os.path.join(args.save, 'args.txt'))

    # Log every argument, coloring the non-default ones.
    with logger.LoggingBlock(blocktitle, emph=True):
        reset = colorama.Style.RESET_ALL
        for argument, value in sorted(vars(args).items()):
            color = reset if value == defaults[argument] else colorama.Fore.CYAN
            logging.info('{}{}: {}{}'.format(color, argument, value, reset))

    # Resolve class constructors from the string-valued arguments.
    return postprocess_args(args)
================================================
FILE: configuration.py
================================================
## Portions of code copyright 2018 Jochen Gast
from __future__ import absolute_import, division, print_function
import os
import torch
from torch import nn
import numpy as np
from torch.utils.data import DataLoader
import logger
import tools
import logging
import shutil
import random
import fnmatch
# ---------------------------------------------------
# Class that contains both the network model and loss
# ---------------------------------------------------
class ModelAndLoss(nn.Module):
    """Bundles a network with its training and evaluation losses.

    The forward pass runs the wrapped model and then whichever loss matches
    the module's current train/eval mode, returning both the loss dictionary
    and the raw model outputs.
    """

    def __init__(self, args, model, training_loss, evaluation_loss=None):
        super(ModelAndLoss, self).__init__()
        self._model = model
        self._training_loss = training_loss
        self._evaluation_loss = evaluation_loss

    @property
    def training_loss(self):
        return self._training_loss

    @property
    def evaluation_loss(self):
        return self._evaluation_loss

    @property
    def model(self):
        return self._model

    def num_parameters(self):
        # Count only parameters that participate in training.
        return sum(p.data.nelement() for p in self.parameters() if p.requires_grad)

    # -------------------------------------------------------------
    # Note: We merge inputs and targets into a single dictionary !
    # -------------------------------------------------------------
    def forward(self, example_dict):
        # Run the model, then the loss that matches the current mode.
        output_dict = self._model(example_dict)
        loss_fn = self._training_loss if self.training else self._evaluation_loss
        loss_dict = loss_fn(output_dict, example_dict)
        return loss_dict, output_dict
def configure_runtime_augmentations(args):
    """Instantiate the training/validation augmentation objects.

    Returns (training_augmentation, validation_augmentation); either is None
    when the corresponding argument was not given. The two branches of the
    original were identical except for the argument prefix, so they now share
    one helper (log output is unchanged).
    """
    with logger.LoggingBlock("Runtime Augmentations", emph=True):

        def _build(name):
            # Build the augmentation selected by args.<name>, or None.
            if getattr(args, name) is None:
                logging.info("%s: None" % name)
                return None
            kwargs = tools.kwargs_from_args(args, name)
            logging.info("%s: %s" % (name, getattr(args, name)))
            for param, default in sorted(kwargs.items()):
                logging.info(" %s: %s" % (param, default))
            kwargs["args"] = args
            augmentation = tools.instance_from_kwargs(
                getattr(args, name + "_class"), kwargs)
            if args.cuda:
                augmentation = augmentation.cuda()
            return augmentation

        training_augmentation = _build("training_augmentation")
        validation_augmentation = _build("validation_augmentation")

    return training_augmentation, validation_augmentation
def configure_model_and_loss(args):
    """Build the network and its losses, moving them to CUDA when requested.

    Model and loss classes are loaded dynamically with parameters passed in
    via "--model_[param]=[value]" or "--loss_[param]=[value]" arguments.
    Returns a ModelAndLoss bundle.
    """
    with logger.LoggingBlock("Model and Loss", emph=True):
        # Model
        model_kwargs = tools.kwargs_from_args(args, "model")
        model_kwargs["args"] = args
        model = tools.instance_from_kwargs(args.model_class, model_kwargs)

        # Training loss (optional)
        training_loss = None
        if args.training_loss is not None:
            loss_kwargs = tools.kwargs_from_args(args, "training_loss")
            loss_kwargs["args"] = args
            training_loss = tools.instance_from_kwargs(args.training_loss_class, loss_kwargs)

        # Validation loss (optional)
        validation_loss = None
        if args.validation_loss is not None:
            loss_kwargs = tools.kwargs_from_args(args, "validation_loss")
            loss_kwargs["args"] = args
            validation_loss = tools.instance_from_kwargs(args.validation_loss_class, loss_kwargs)

        # Bundle model and losses together.
        model_and_loss = ModelAndLoss(args, model, training_loss, validation_loss)

        # Transfer to CUDA when requested.
        if args.cuda:
            model_and_loss = model_and_loss.cuda()

        # Report some network statistics.
        logging.info("Batch Size: %i" % args.batch_size)
        if args.cuda:
            logging.info("GPGPU: Cuda")
        else:
            logging.info("GPGPU: off")
        logging.info("Network: %s" % args.model)
        logging.info("Number of parameters: %i" % tools.x2module(model_and_loss).num_parameters())
        if training_loss is not None:
            logging.info("Training Key: %s" % args.training_key)
            logging.info("Training Loss: %s" % args.training_loss)
        if validation_loss is not None:
            logging.info("Validation Key: %s" % args.validation_key)
            logging.info("Validation Loss: %s" % args.validation_loss)

    return model_and_loss
def configure_random_seed(args):
    """Seed python/numpy/torch RNGs from args.seed using consecutive seeds."""
    with logger.LoggingBlock("Random Seeds", emph=True):
        # Each library gets its own consecutive seed, logged as before.
        seeders = [
            (random.seed, "Python seed: %i"),
            (np.random.seed, "Numpy seed: %i"),
            (torch.manual_seed, "Torch CPU seed: %i"),
            (torch.cuda.manual_seed, "Torch CUDA seed: %i"),
        ]
        for offset, (seed_fn, message) in enumerate(seeders):
            seed = args.seed + offset
            seed_fn(seed)
            logging.info(message % seed)
# --------------------------------------------------------------------------
# Checkpoint loader/saver.
# --------------------------------------------------------------------------
class CheckpointSaver:
    """Saves and restores model checkpoints plus their scalar statistics.

    Files are named "<prefix><postfix><extension>" (e.g. "checkpoint_latest.ckpt");
    a sibling ".json" file holds the statistics dictionary of the saved model.
    """

    def __init__(self,
                 prefix="checkpoint",
                 latest_postfix="_latest",
                 best_postfix="_best",
                 model_key="state_dict",
                 extension=".ckpt"):
        self._prefix = prefix
        self._model_key = model_key
        self._latest_postfix = latest_postfix
        self._best_postfix = best_postfix
        self._extension = extension

    # the purpose of rewriting the loading function is we sometimes want to
    # initialize parameters in modules without knowing the dimensions at runtime
    #
    # This function here will resize these parameters to whatever size required.
    def _load_state_dict_into_module(self, state_dict, module, strict=True):
        """Copy `state_dict` into `module`, resizing destination tensors as needed.

        With strict=True, unexpected or missing keys raise KeyError.
        """
        own_state = module.state_dict()
        for name, param in state_dict.items():
            if name in own_state:
                if isinstance(param, nn.Parameter):
                    # backwards compatibility for serialized parameters
                    param = param.data
                try:
                    own_state[name].resize_as_(param)
                    own_state[name].copy_(param)
                except Exception as err:
                    # chain the original error so the root cause is not lost
                    raise RuntimeError('While copying the parameter named {}, '
                                       'whose dimensions in the model are {} and '
                                       'whose dimensions in the checkpoint are {}.'
                                       .format(name, own_state[name].size(), param.size())) from err
            elif strict:
                raise KeyError('unexpected key "{}" in state_dict'
                               .format(name))
        if strict:
            missing = set(own_state.keys()) - set(state_dict.keys())
            if len(missing) > 0:
                raise KeyError('missing keys in state_dict: "{}"'.format(missing))

    def restore(self, filename, model_and_loss, include_params="*", exclude_params=()):
        """Restore a checkpoint file into `model_and_loss`.

        Returns (checkpoint_stats, filename) where checkpoint_stats is every
        entry of the checkpoint except the state dict itself. Exits the
        process when the file does not exist.
        """
        # -----------------------------------------------------------------------------------------
        # Make sure file exists
        # -----------------------------------------------------------------------------------------
        if not os.path.isfile(filename):
            logging.info("Could not find checkpoint file '%s'!" % filename)
            quit()

        # -----------------------------------------------------------------------------------------
        # Load checkpoint from file including the state_dict.
        # BUGFIX: map to CPU so checkpoints written on GPU machines can be
        # restored on CPU-only hosts; copy_() moves tensors to the right
        # device afterwards.
        # -----------------------------------------------------------------------------------------
        checkpoint_with_state = torch.load(filename, map_location="cpu")

        # -----------------------------------------------------------------------------------------
        # Load filtered state dictionary
        # -----------------------------------------------------------------------------------------
        state_dict = checkpoint_with_state[self._model_key]
        restore_keys = tools.filter_list_of_strings(
            state_dict.keys(),
            include=include_params,
            exclude=exclude_params)
        state_dict = {key: value for key, value in state_dict.items() if key in restore_keys}
        self._load_state_dict_into_module(state_dict, model_and_loss)

        # -----------------------------------------------------------------------------------------
        # Get checkpoint statistics without the state dict
        # -----------------------------------------------------------------------------------------
        checkpoint_stats = {
            key: value for key, value in checkpoint_with_state.items() if key != self._model_key
        }
        return checkpoint_stats, filename

    def restore_latest(self, directory, model_and_loss, include_params="*", exclude_params=()):
        """Restore "<prefix>_latest<extension>" from `directory`."""
        latest_checkpoint_filename = os.path.join(
            directory, self._prefix + self._latest_postfix + self._extension)
        return self.restore(latest_checkpoint_filename, model_and_loss, include_params, exclude_params)

    def restore_best(self, directory, model_and_loss, include_params="*", exclude_params=()):
        """Restore "<prefix>_best<extension>" from `directory`."""
        best_checkpoint_filename = os.path.join(
            directory, self._prefix + self._best_postfix + self._extension)
        return self.restore(best_checkpoint_filename, model_and_loss, include_params, exclude_params)

    def save_latest(self, directory, model_and_loss, stats_dict, store_as_best=False):
        """Save the latest checkpoint (+ stats json); optionally copy it as best."""
        # -----------------------------------------------------------------------------------------
        # Make sure directory exists
        # -----------------------------------------------------------------------------------------
        tools.ensure_dir(directory)

        # -----------------------------------------------------------------------------------------
        # Save
        # -----------------------------------------------------------------------------------------
        save_dict = dict(stats_dict)
        save_dict[self._model_key] = model_and_loss.state_dict()
        latest_checkpoint_filename = os.path.join(
            directory, self._prefix + self._latest_postfix + self._extension)
        latest_statistics_filename = os.path.join(
            directory, self._prefix + self._latest_postfix + ".json")
        torch.save(save_dict, latest_checkpoint_filename)
        tools.write_json(data_dict=stats_dict, filename=latest_statistics_filename)

        # -----------------------------------------------------------------------------------------
        # Possibly store as best
        # -----------------------------------------------------------------------------------------
        if store_as_best:
            best_checkpoint_filename = os.path.join(
                directory, self._prefix + self._best_postfix + self._extension)
            best_statistics_filename = os.path.join(
                directory, self._prefix + self._best_postfix + ".json")
            logging.info("Saved checkpoint as best model..")
            shutil.copyfile(latest_checkpoint_filename, best_checkpoint_filename)
            shutil.copyfile(latest_statistics_filename, best_statistics_filename)
def configure_checkpoint_saver(args, model_and_loss):
    """Create a CheckpointSaver and optionally restore from args.checkpoint.

    args.checkpoint may be None (train from scratch), a checkpoint file, or a
    directory whose contents are resolved via args.checkpoint_mode
    ("resume_from_best" / "resume_from_latest"). Returns (saver, stats) where
    stats is None when nothing was restored. Exits on invalid input.
    """
    with logger.LoggingBlock("Checkpoint", emph=True):
        checkpoint_saver = CheckpointSaver()
        checkpoint_stats = None

        if args.checkpoint is None:
            logging.info("No checkpoint given.")
            logging.info("Starting from scratch with random initialization.")

        elif os.path.isfile(args.checkpoint):
            checkpoint_stats, filename = checkpoint_saver.restore(
                filename=args.checkpoint,
                model_and_loss=model_and_loss,
                include_params=args.checkpoint_include_params,
                exclude_params=args.checkpoint_exclude_params)

        elif os.path.isdir(args.checkpoint):
            if args.checkpoint_mode in ["resume_from_best"]:
                logging.info("Loading best checkpoint in %s" % args.checkpoint)
                checkpoint_stats, filename = checkpoint_saver.restore_best(
                    directory=args.checkpoint,
                    model_and_loss=model_and_loss,
                    include_params=args.checkpoint_include_params,
                    exclude_params=args.checkpoint_exclude_params)
            elif args.checkpoint_mode in ["resume_from_latest"]:
                logging.info("Loading latest checkpoint in %s" % args.checkpoint)
                checkpoint_stats, filename = checkpoint_saver.restore_latest(
                    directory=args.checkpoint,
                    model_and_loss=model_and_loss,
                    include_params=args.checkpoint_include_params,
                    exclude_params=args.checkpoint_exclude_params)
            else:
                # BUGFIX: the original read args.checkpoint_restore, an attribute
                # that does not exist, so the bad mode raised AttributeError
                # instead of being reported.
                logging.info("Unknown checkpoint_mode '%s' given!" % args.checkpoint_mode)
                quit()

        else:
            logging.info("Could not find checkpoint file or directory '%s'" % args.checkpoint)
            quit()

    return checkpoint_saver, checkpoint_stats
# -------------------------------------------------------------------------------------------------
# Configure data loading
# -------------------------------------------------------------------------------------------------
def configure_data_loaders(args):
    """Create (train_loader, validation_loader, inference_loader) from args.

    The training and validation constructions only differ in argument prefix,
    batch-size attribute, and shuffling, so they share one helper (the
    original duplicated the code). inference_loader is never populated here
    and stays None.
    """
    with logger.LoggingBlock("Datasets", emph=True):

        def _sizes_to_str(value):
            # Scalars carry no size() -- display them as a single element.
            if np.isscalar(value):
                return '[1L]'
            else:
                return ' '.join([str([d for d in value.size()])])

        def _log_statistics(dataset, prefix, name):
            # Log key/shape info from the first example plus the dataset length.
            with logger.LoggingBlock("%s Dataset: %s" % (prefix, name)):
                example_dict = dataset[0]  # get sizes from first dataset example
                for key, value in sorted(example_dict.items()):
                    if key in ["index", "basename"]:  # no need to display these
                        continue
                    if isinstance(value, str):
                        logging.info("{}: {}".format(key, value))
                    else:
                        logging.info("%s: %s" % (key, _sizes_to_str(value)))
                logging.info("num_examples: %i" % len(dataset))

        # -----------------------------------------------------------------------------------------
        # GPU parameters -- turning off pin_memory? for resolving the deadlock?
        # -----------------------------------------------------------------------------------------
        gpuargs = {"num_workers": args.num_workers, "pin_memory": False} if args.cuda else {}

        def _make_loader(arg_name, prefix, batch_size_attr, shuffle):
            # Build the dataset selected by args.<arg_name> and wrap it in a
            # DataLoader; returns None when no dataset was requested.
            dataset_name = getattr(args, arg_name)
            if dataset_name is None:
                return None
            kwargs = tools.kwargs_from_args(args, arg_name)
            kwargs["is_cropped"] = True
            kwargs["args"] = args
            dataset = tools.instance_from_kwargs(getattr(args, arg_name + "_class"), kwargs)
            loader = DataLoader(
                dataset,
                batch_size=getattr(args, batch_size_attr),
                shuffle=shuffle,
                drop_last=False,
                **gpuargs)
            _log_statistics(dataset, prefix=prefix, name=dataset_name)
            return loader

        train_loader = _make_loader("training_dataset", "Training", "batch_size", shuffle=True)
        validation_loader = _make_loader("validation_dataset", "Validation", "batch_size_val", shuffle=False)
        inference_loader = None

    return train_loader, validation_loader, inference_loader
# ------------------------------------------------------------
# Generator for trainable parameters by pattern matching
# ------------------------------------------------------------
def _print_trainable_params(model_and_loss, match="*"):
sum = 0
for name, p in model_and_loss.named_parameters():
if fnmatch.fnmatch(name, match):
if p.requires_grad:
logging.info(name)
logging.info(str(p.numel()))
print(name)
print(p.numel())
sum += p.numel()
logging.info(str(sum))
def _generate_trainable_params(model_and_loss, match="*"):
for name, p in model_and_loss.named_parameters():
if fnmatch.fnmatch(name, match):
if p.requires_grad:
yield p
def _param_names_and_trainable_generator(model_and_loss, match="*"):
    """Return (names of matching trainable params, generator over those params)."""
    names = [
        name
        for name, param in model_and_loss.named_parameters()
        if fnmatch.fnmatch(name, match) and param.requires_grad
    ]
    return names, _generate_trainable_params(model_and_loss, match=match)
# -------------------------------------------------------------------------------------------------
# Build optimizer:
# -------------------------------------------------------------------------------------------------
def configure_optimizer(args, model_and_loss):
    """Build the optimizer over the trainable parameters (or parameter groups).

    Parameter groups come from the "--optimizer_group" argument; parameters
    not claimed by any group end up in a default group. Returns None when no
    optimizer was requested or nothing is trainable.
    """
    optimizer = None
    with logger.LoggingBlock("Optimizer", emph=True):
        if args.optimizer is not None:
            if model_and_loss.num_parameters() == 0:
                logging.info("No trainable parameters detected.")
                logging.info("Setting optimizer to None.")
            else:
                logging.info(args.optimizer)

                # -------------------------------------------
                # Figure out all optimizer arguments
                # -------------------------------------------
                all_kwargs = tools.kwargs_from_args(args, "optimizer")

                # -------------------------------------------
                # Get the split of param groups
                # -------------------------------------------
                kwargs_without_groups = {
                    key: value for key, value in all_kwargs.items() if key != "group"
                }
                param_groups = all_kwargs["group"]

                # ----------------------------------------------------------------------
                # Print arguments (without groups)
                # ----------------------------------------------------------------------
                for param, default in sorted(kwargs_without_groups.items()):
                    logging.info("%s: %s" % (param, default))

                # ----------------------------------------------------------------------
                # Construct actual optimizer params
                # ----------------------------------------------------------------------
                kwargs = dict(kwargs_without_groups)
                if param_groups is None:
                    # Add all trainable parameters if there are no param groups.
                    kwargs["params"] = _generate_trainable_params(model_and_loss)
                else:
                    trainable_parameter_groups = []
                    dnames, dparams = _param_names_and_trainable_generator(model_and_loss)
                    dnames = set(dnames)
                    dparams = set(list(dparams))
                    with logger.LoggingBlock("parameter_groups:"):
                        for group in param_groups:
                            # log group settings
                            group_match = group["params"]
                            group_args = {
                                key: value for key, value in group.items() if key != "params"
                            }
                            with logger.LoggingBlock("%s: %s" % (group_match, group_args)):
                                # retrieve parameters by matching name
                                gnames, gparams = _param_names_and_trainable_generator(
                                    model_and_loss, match=group_match)
                                # BUGFIX: materialize the generator exactly once.
                                # The original stored the generator in the param
                                # group and then exhausted it via set(list(gparams))
                                # below, handing the optimizer an empty group.
                                gparams = list(gparams)
                                # log all names affected
                                for n in sorted(gnames):
                                    logging.info(n)
                                # set parameter list for group
                                group_args["params"] = gparams
                                # append parameter group
                                trainable_parameter_groups.append(group_args)
                                # update remaining trainable parameters
                                dnames -= set(gnames)
                                dparams -= set(gparams)
                        # append default parameter group
                        trainable_parameter_groups.append({"params": list(dparams)})
                        # and log its parameter names
                        with logger.LoggingBlock("default:"):
                            for dname in sorted(dnames):
                                logging.info(dname)
                    # set params in optimizer kwargs
                    kwargs["params"] = trainable_parameter_groups

                # -------------------------------------------
                # Create optimizer instance
                # -------------------------------------------
                optimizer = tools.instance_from_kwargs(args.optimizer_class, kwargs)
    return optimizer
# -------------------------------------------------------------------------------------------------
# Configure learning rate scheduler
# -------------------------------------------------------------------------------------------------
def configure_lr_scheduler(args, optimizer):
    """Create the learning-rate scheduler bound to `optimizer` (or None)."""
    lr_scheduler = None
    with logger.LoggingBlock("Learning Rate Scheduler", emph=True):
        # The selected class name is logged even when no scheduler was chosen.
        logging.info("class: %s" % args.lr_scheduler)
        if args.lr_scheduler is not None:
            # Collect and log the "--lr_scheduler_[param]" arguments.
            scheduler_kwargs = tools.kwargs_from_args(args, "lr_scheduler")
            for param, default in sorted(scheduler_kwargs.items()):
                logging.info("%s: %s" % (param, default))
            # The scheduler constructor additionally needs the optimizer.
            scheduler_kwargs["optimizer"] = optimizer
            lr_scheduler = tools.instance_from_kwargs(args.lr_scheduler_class, scheduler_kwargs)
    return lr_scheduler
================================================
FILE: datasets/__init__.py
================================================
from . import flyingchairs
from . import flyingchairsOcc
from . import sintel
from . import flyingThings3D
from . import kitti_combined
# (a duplicate "from . import sintel" was removed)

## FlyingChairs
FlyingChairsTrain = flyingchairs.FlyingChairsTrain
FlyingChairsValid = flyingchairs.FlyingChairsValid
FlyingChairsFull = flyingchairs.FlyingChairsFull

## Our custom FlyingChairs + Occ
FlyingChairsOccTrain = flyingchairsOcc.FlyingChairsOccTrain
FlyingChairsOccValid = flyingchairsOcc.FlyingChairsOccValid
FlyingChairsOccFull = flyingchairsOcc.FlyingChairsOccFull

## FlyingThings3D_subset
FlyingThings3dFinalTrain = flyingThings3D.FlyingThings3dFinalTrain
FlyingThings3dFinalTest = flyingThings3D.FlyingThings3dFinalTest
FlyingThings3dCleanTrain = flyingThings3D.FlyingThings3dCleanTrain
FlyingThings3dCleanTest = flyingThings3D.FlyingThings3dCleanTest

## Sintel
SintelTestClean = sintel.SintelTestClean
SintelTestFinal = sintel.SintelTestFinal
SintelTrainingCombFull = sintel.SintelTrainingCombFull
SintelTrainingCombTrain = sintel.SintelTrainingCombTrain
SintelTrainingCombValid = sintel.SintelTrainingCombValid
SintelTrainingCleanFull = sintel.SintelTrainingCleanFull
SintelTrainingCleanTrain = sintel.SintelTrainingCleanTrain
SintelTrainingCleanValid = sintel.SintelTrainingCleanValid
SintelTrainingFinalFull = sintel.SintelTrainingFinalFull
SintelTrainingFinalTrain = sintel.SintelTrainingFinalTrain
SintelTrainingFinalValid = sintel.SintelTrainingFinalValid

## KITTI Optical Flow 2012 + 2015
KittiCombTrain = kitti_combined.KittiCombTrain
KittiCombVal = kitti_combined.KittiCombVal
KittiCombFull = kitti_combined.KittiCombFull
KittiComb2012Train = kitti_combined.KittiComb2012Train
KittiComb2012Val = kitti_combined.KittiComb2012Val
KittiComb2012Full = kitti_combined.KittiComb2012Full
KittiComb2012Test = kitti_combined.KittiComb2012Test
KittiComb2015Train = kitti_combined.KittiComb2015Train
KittiComb2015Val = kitti_combined.KittiComb2015Val
KittiComb2015Full = kitti_combined.KittiComb2015Full
KittiComb2015Test = kitti_combined.KittiComb2015Test
================================================
FILE: datasets/common.py
================================================
## Portions of code copyright 2018 Jochen Gast
from __future__ import absolute_import, division, print_function
import torch
import numpy as np
import skimage.io as io
def numpy2torch(array):
    """Convert an HWC (or HW) numpy array into a CHW float32 torch tensor."""
    assert isinstance(array, np.ndarray)
    if array.ndim == 3:
        # HWC -> CHW
        chw = np.transpose(array, (2, 0, 1))
    else:
        # Prepend a singleton channel dimension.
        chw = array[np.newaxis, ...]
    # Copy so the tensor does not alias the caller's array.
    return torch.from_numpy(chw.copy()).float()
def read_flo_as_float32(filename):
    """Read a Middlebury .flo optical-flow file as an (H, W, 2) float32 array.

    Raises AssertionError on a wrong magic number and ValueError on a
    truncated payload. (The original used np.resize, which silently pads or
    repeats data from corrupt files instead of failing.)
    """
    with open(filename, 'rb') as file:
        magic = np.fromfile(file, np.float32, count=1)
        assert(202021.25 == magic), "Magic number incorrect. Invalid .flo file"
        w = np.fromfile(file, np.int32, count=1)[0]
        h = np.fromfile(file, np.int32, count=1)[0]
        data = np.fromfile(file, np.float32, count=2 * h * w)
        # reshape (unlike np.resize) raises if the file is short.
        return data.reshape(h, w, 2)
def read_occ_image_as_float32(filename):
    """Load an occlusion mask as float32 scaled to [0, 1], single channel."""
    occ = io.imread(filename).astype(np.float32) / np.float32(255.0)
    # Multi-channel masks keep only channel 0 -- presumably all channels are
    # identical in the dataset's occlusion PNGs (TODO confirm).
    return occ[:, :, 0] if occ.ndim == 3 else occ
def read_image_as_float32(filename):
    """Read an image file and rescale its values to float32 in [0, 1]."""
    image = io.imread(filename)
    return image.astype(np.float32) / np.float32(255.0)
def read_image_as_byte(filename):
    """Read an image file without any dtype conversion (raw uint8 values)."""
    raw = io.imread(filename)
    return raw
================================================
FILE: datasets/flyingThings3D.py
================================================
from __future__ import absolute_import, division, print_function
import os
import torch.utils.data as data
from glob import glob
from torchvision import transforms as vision_transforms
from . import transforms
from . import common
import numpy as np
def fillingInNaN(flow):
    """Replace NaN entries of an (H, W, C) flow array, in place, with the mean
    of their valid 4-neighborhood values in the same channel.

    Entries with no valid neighbor are reported and left as NaN (the original
    printed the error and then divided by zero). Returns the mutated array.
    """
    h, w, _c = flow.shape
    neighbors = ((-1, 0), (1, 0), (0, -1), (0, 1))
    for idx in np.argwhere(np.isnan(flow)):
        sum_sample = 0.0
        count = 0
        # BUGFIX: iterate over all four neighbors; the original looped over
        # range(0, len(neighbors) - 1) and never considered (0, 1).
        for dh, dw in neighbors:
            hh = idx[0] + dh
            ww = idx[1] + dw
            if hh < 0 or hh >= h or ww < 0 or ww >= w:
                continue
            sample_flow = flow[hh, ww, idx[2]]
            if np.isnan(sample_flow):
                continue
            sum_sample += sample_flow
            count += 1
        # BUGFIX: was `count is 0` (identity comparison on an int); also skip
        # instead of dividing by zero when there is no valid neighbor.
        if count == 0:
            print('FATAL ERROR: no sample')
            continue
        flow[idx[0], idx[1], idx[2]] = sum_sample / count
    return flow
class FlyingThings3d(data.Dataset):
def __init__(self,
args,
images_root,
flow_root,
occ_root,
photometric_augmentations=False):
self._args = args
if not os.path.isdir(images_root):
raise ValueError("Image directory '%s' not found!")
if flow_root is not None and not os.path.isdir(flow_root):
raise ValueError("Flow directory '%s' not found!")
if occ_root is not None and not os.path.isdir(occ_root):
raise ValueError("Occ directory '%s' not found!")
if flow_root is not None:
flow_f_filenames = sorted(glob(os.path.join(flow_root, "into_future/*.flo")))
flow_b_filenames = sorted(glob(os.path.join(flow_root, "into_past/*.flo")))
if occ_root is not None:
occ1_filenames = sorted(glob(os.path.join(occ_root, "into_future/*.png")))
occ2_filenames = sorted(glob(os.path.join(occ_root, "into_past/*.png")))
all_img_filenames = sorted(glob(os.path.join(images_root, "*.png")))
self._image_list = []
self._flow_list = [] if flow_root is not None else None
self._occ_list = [] if occ_root is not None else None
assert len(all_img_filenames) != 0
assert len(flow_f_filenames) != 0
assert len(flow_b_filenames) != 0
assert len(occ1_filenames) != 0
assert len(occ2_filenames) != 0
## path definition
path_flow_f = os.path.join(flow_root, "into_future")
path_flow_b = os.path.join(flow_root, "into_past")
path_occ_f = os.path.join(occ_root, "into_future")
path_occ_b = os.path.join(occ_root, "into_past")
# ----------------------------------------------------------
# Save list of actual filenames for inputs and flows
# ----------------------------------------------------------
for ii in range(0, len(flow_f_filenames)):
flo_f = flow_f_filenames[ii]
idx_f = os.path.splitext(os.path.basename(flo_f))[0]
idx_b = str(int(idx_f) + 1).zfill(7)
flo_b = os.path.join(path_flow_b, idx_b + ".flo")
im1 = os.path.join(images_root, idx_f + ".png")
im2 = os.path.join(images_root, idx_b + ".png")
occ1 = os.path.join(path_occ_f, idx_f + ".png")
occ2 = os.path.join(path_occ_b, idx_b + ".png")
if not os.path.isfile(flo_f) or not os.path.isfile(flo_b) or not os.path.isfile(im1) or not os.path.isfile(
im2) or not os.path.isfile(occ1) or not os.path.isfile(occ2):
continue
self._image_list += [[im1, im2]]
self._flow_list += [[flo_f, flo_b]]
self._occ_list += [[occ1, occ2]]
self._size = len(self._image_list)
assert len(self._image_list) == len(self._flow_list)
assert len(self._occ_list) == len(self._flow_list)
assert len(self._image_list) != 0
# ----------------------------------------------------------
# photometric_augmentations
# ----------------------------------------------------------
if photometric_augmentations:
self._photometric_transform = transforms.ConcatTransformSplitChainer([
# uint8 -> PIL
vision_transforms.ToPILImage(),
# PIL -> PIL : random hsv and contrast
vision_transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
# PIL -> FloatTensor
vision_transforms.transforms.ToTensor(),
transforms.RandomGamma(min_gamma=0.7, max_gamma=1.5, clip_image=True),
], from_numpy=True, to_numpy=False)
else:
self._photometric_transform = transforms.ConcatTransformSplitChainer([
# uint8 -> FloatTensor
vision_transforms.transforms.ToTensor(),
], from_numpy=True, to_numpy=False)
def __getitem__(self, index):
index = index % self._size
im1_filename = self._image_list[index][0]
im2_filename = self._image_list[index][1]
flo_f_filename = self._flow_list[index][0]
flo_b_filename = self._flow_list[index][1]
occ1_filename = self._occ_list[index][0]
occ2_filename = self._occ_list[index][1]
# read float32 images and flow
im1_np0 = common.read_image_as_byte(im1_filename)
im2_np0 = common.read_image_as_byte(im2_filename)
flo_f_np0 = common.read_flo_as_float32(flo_f_filename)
flo_b_np0 = common.read_flo_as_float32(flo_b_filename)
occ1_np0 = common.read_occ_image_as_float32(occ1_filename)
occ2_np0 = common.read_occ_image_as_float32(occ2_filename)
# temp - check isnan
if np.any(np.isnan(flo_f_np0)):
flo_f_np0 = fillingInNaN(flo_f_np0)
if np.any(np.isnan(flo_b_np0)):
flo_b_np0 = fillingInNaN(flo_b_np0)
# possibly apply photometric transformations
im1, im2 = self._photometric_transform(im1_np0, im2_np0)
# convert flow to FloatTensor
flo_f = common.numpy2torch(flo_f_np0)
flo_b = common.numpy2torch(flo_b_np0)
# convert occ to FloatTensor
occ1 = common.numpy2torch(occ1_np0)
occ2 = common.numpy2torch(occ2_np0)
# example filename
basename = os.path.basename(im1_filename)[:5]
example_dict = {
"input1": im1,
"input2": im2,
"target1": flo_f,
"target2": flo_b,
"target_occ1": occ1,
"target_occ2": occ2,
"index": index,
"basename": basename
}
return example_dict
def __len__(self):
    # Dataset length as required by torch.utils.data.Dataset.
    return self._size
class FlyingThings3dFinalTrain(FlyingThings3d):
    """FlyingThings3D 'final' pass, training configuration (augmentations on)."""

    def __init__(self, args, root, photometric_augmentations=True):
        super(FlyingThings3dFinalTrain, self).__init__(
            args,
            images_root=os.path.join(root, "frames_finalpass"),
            flow_root=os.path.join(root, "optical_flow"),
            occ_root=os.path.join(root, "occlusion"),
            photometric_augmentations=photometric_augmentations)
class FlyingThings3dFinalTest(FlyingThings3d):
    """FlyingThings3D 'final' pass, evaluation configuration (augmentations off)."""

    def __init__(self, args, root, photometric_augmentations=False):
        super(FlyingThings3dFinalTest, self).__init__(
            args,
            images_root=os.path.join(root, "frames_finalpass"),
            flow_root=os.path.join(root, "optical_flow"),
            occ_root=os.path.join(root, "occlusion"),
            photometric_augmentations=photometric_augmentations)
class FlyingThings3dCleanTrain(FlyingThings3d):
    """FlyingThings3D 'clean' pass, training configuration (augmentations on).

    NOTE(review): unlike the Final variants, this uses the
    train/image_clean/left directory layout — presumably the
    FlyingThings3D-subset release; confirm against the dataset root in use.
    """

    def __init__(self, args, root, photometric_augmentations=True):
        super(FlyingThings3dCleanTrain, self).__init__(
            args,
            images_root=os.path.join(root, "train", "image_clean", "left"),
            flow_root=os.path.join(root, "train", "flow", "left"),
            occ_root=os.path.join(root, "train", "flow_occlusions", "left"),
            photometric_augmentations=photometric_augmentations)
class FlyingThings3dCleanTest(FlyingThings3d):
    """FlyingThings3D 'clean' pass, evaluation configuration (augmentations off)."""

    def __init__(self, args, root, photometric_augmentations=False):
        super(FlyingThings3dCleanTest, self).__init__(
            args,
            images_root=os.path.join(root, "frames_cleanpass"),
            flow_root=os.path.join(root, "optical_flow"),
            occ_root=os.path.join(root, "occlusion"),
            photometric_augmentations=photometric_augmentations)
================================================
FILE: datasets/flyingchairs.py
================================================
from __future__ import absolute_import, division, print_function
import os
import torch.utils.data as data
from glob import glob
from torchvision import transforms as vision_transforms
from . import transforms
from . import common
# Example indices held out of FlyingChairs for validation; the remaining
# indices form the training split (see the dstype dispatch in
# FlyingChairs.__init__ below).
VALIDATE_INDICES = [
    5, 17, 42, 45, 58, 62, 96, 111, 117, 120, 121, 131, 132,
    152, 160, 248, 263, 264, 291, 293, 295, 299, 316, 320, 336,
    337, 343, 358, 399, 401, 429, 438, 468, 476, 494, 509, 528,
    531, 572, 581, 583, 588, 593, 681, 688, 696, 714, 767, 786,
    810, 825, 836, 841, 883, 917, 937, 942, 970, 974, 980, 1016,
    1043, 1064, 1118, 1121, 1133, 1153, 1155, 1158, 1159, 1173,
    1187, 1219, 1237, 1238, 1259, 1266, 1278, 1296, 1354, 1378,
    1387, 1494, 1508, 1518, 1574, 1601, 1614, 1668, 1673, 1699,
    1712, 1714, 1737, 1841, 1872, 1879, 1901, 1921, 1934, 1961,
    1967, 1978, 2018, 2030, 2039, 2043, 2061, 2113, 2204, 2216,
    2236, 2250, 2274, 2292, 2310, 2342, 2359, 2374, 2382, 2399,
    2415, 2419, 2483, 2502, 2504, 2576, 2589, 2590, 2622, 2624,
    2636, 2651, 2655, 2658, 2659, 2664, 2672, 2706, 2707, 2709,
    2725, 2732, 2761, 2827, 2864, 2866, 2905, 2922, 2929, 2966,
    2972, 2993, 3010, 3025, 3031, 3040, 3041, 3070, 3113, 3124,
    3129, 3137, 3141, 3157, 3183, 3206, 3219, 3247, 3253, 3272,
    3276, 3321, 3328, 3333, 3338, 3341, 3346, 3351, 3396, 3419,
    3430, 3433, 3448, 3455, 3463, 3503, 3526, 3529, 3537, 3555,
    3577, 3584, 3591, 3594, 3597, 3603, 3613, 3615, 3670, 3676,
    3678, 3697, 3723, 3728, 3734, 3745, 3750, 3752, 3779, 3782,
    3813, 3817, 3819, 3854, 3885, 3944, 3947, 3970, 3985, 4011,
    4022, 4071, 4075, 4132, 4158, 4167, 4190, 4194, 4207, 4246,
    4249, 4298, 4307, 4317, 4318, 4319, 4320, 4382, 4399, 4401,
    4407, 4416, 4423, 4484, 4491, 4493, 4517, 4525, 4538, 4578,
    4606, 4609, 4620, 4623, 4637, 4646, 4662, 4668, 4716, 4739,
    4747, 4770, 4774, 4776, 4785, 4800, 4845, 4863, 4891, 4904,
    4922, 4925, 4956, 4963, 4964, 4994, 5011, 5019, 5036, 5038,
    5041, 5055, 5118, 5122, 5130, 5162, 5164, 5178, 5196, 5227,
    5266, 5270, 5273, 5279, 5299, 5310, 5314, 5363, 5375, 5384,
    5393, 5414, 5417, 5433, 5448, 5494, 5505, 5509, 5525, 5566,
    5581, 5602, 5609, 5620, 5653, 5670, 5678, 5690, 5700, 5703,
    5724, 5752, 5765, 5803, 5811, 5860, 5881, 5895, 5912, 5915,
    5940, 5952, 5966, 5977, 5988, 6007, 6037, 6061, 6069, 6080,
    6111, 6127, 6146, 6161, 6166, 6168, 6178, 6182, 6190, 6220,
    6235, 6253, 6270, 6343, 6372, 6379, 6410, 6411, 6442, 6453,
    6481, 6498, 6500, 6509, 6532, 6541, 6543, 6560, 6576, 6580,
    6594, 6595, 6609, 6625, 6629, 6644, 6658, 6673, 6680, 6698,
    6699, 6702, 6705, 6741, 6759, 6785, 6792, 6794, 6809, 6810,
    6830, 6838, 6869, 6871, 6889, 6925, 6995, 7003, 7026, 7029,
    7080, 7082, 7097, 7102, 7116, 7165, 7200, 7232, 7271, 7282,
    7324, 7333, 7335, 7372, 7387, 7407, 7472, 7474, 7482, 7489,
    7499, 7516, 7533, 7536, 7566, 7620, 7654, 7691, 7704, 7722,
    7746, 7750, 7773, 7806, 7821, 7827, 7851, 7873, 7880, 7884,
    7904, 7912, 7948, 7964, 7965, 7984, 7989, 7992, 8035, 8050,
    8074, 8091, 8094, 8113, 8116, 8151, 8159, 8171, 8179, 8194,
    8195, 8239, 8263, 8290, 8295, 8312, 8367, 8374, 8387, 8407,
    8437, 8439, 8518, 8556, 8588, 8597, 8601, 8651, 8657, 8723,
    8759, 8763, 8785, 8802, 8813, 8826, 8854, 8856, 8866, 8918,
    8922, 8923, 8932, 8958, 8967, 9003, 9018, 9078, 9095, 9104,
    9112, 9129, 9147, 9170, 9171, 9197, 9200, 9249, 9253, 9270,
    9282, 9288, 9295, 9321, 9323, 9324, 9347, 9399, 9403, 9417,
    9426, 9427, 9439, 9468, 9486, 9496, 9511, 9516, 9518, 9529,
    9557, 9563, 9564, 9584, 9586, 9591, 9599, 9600, 9601, 9632,
    9654, 9667, 9678, 9696, 9716, 9723, 9740, 9820, 9824, 9825,
    9828, 9863, 9866, 9868, 9889, 9929, 9938, 9953, 9967, 10019,
    10020, 10025, 10059, 10111, 10118, 10125, 10174, 10194,
    10201, 10202, 10220, 10221, 10226, 10242, 10250, 10276,
    10295, 10302, 10305, 10327, 10351, 10360, 10369, 10393,
    10407, 10438, 10455, 10463, 10465, 10470, 10478, 10503,
    10508, 10509, 10809, 11080, 11331, 11607, 11610, 11864,
    12390, 12393, 12396, 12399, 12671, 12921, 12930, 13178,
    13453, 13717, 14499, 14517, 14775, 15297, 15556, 15834,
    15839, 16126, 16127, 16386, 16633, 16644, 16651, 17166,
    17169, 17958, 17959, 17962, 18224, 21176, 21180, 21190,
    21802, 21803, 21806, 22584, 22857, 22858, 22866]
class FlyingChairs(data.Dataset):
    """FlyingChairs optical-flow dataset: image pairs plus forward flow.

    Args:
        args: global configuration namespace (stored, not interpreted here).
        root: directory containing the ``*.ppm`` image pairs and ``*.flo`` flows.
        photometric_augmentations: if True, apply random color jitter and gamma.
        dstype: split selection — "train", "valid", or "full".

    Raises:
        ValueError: if ``dstype`` is not one of the three known splits.
    """

    def __init__(self,
                 args,
                 root,
                 photometric_augmentations=False,
                 dstype="train"):
        self._args = args

        # -------------------------------------------------------------
        # filenames for all input images and target flows
        # -------------------------------------------------------------
        image_filenames = sorted(glob(os.path.join(root, "*.ppm")))
        flow_filenames = sorted(glob(os.path.join(root, "*.flo")))
        # every flow field must correspond to exactly one image pair
        assert (len(image_filenames) / 2 == len(flow_filenames))
        num_flows = len(flow_filenames)

        # -------------------------------------------------------------
        # Keep only validation indices that exist for this root
        # -------------------------------------------------------------
        validate_indices = [x for x in VALIDATE_INDICES if x in range(num_flows)]

        # ----------------------------------------------------------
        # Construct list of indices for training/validation
        # ----------------------------------------------------------
        if dstype == "train":
            # set() turns the O(n^2) exclusion scan into O(n)
            validate_set = set(validate_indices)
            list_of_indices = [x for x in range(num_flows) if x not in validate_set]
        elif dstype == "valid":
            list_of_indices = validate_indices
        elif dstype == "full":
            list_of_indices = range(num_flows)
        else:
            # BUGFIX: the message was never %-formatted (the value was passed
            # as a second exception argument instead of being interpolated)
            raise ValueError("FlyingChairs: dstype '%s' unknown!" % dstype)

        # ----------------------------------------------------------
        # Save list of actual filenames for inputs and flows
        # (images come in consecutive pairs: 2*i and 2*i + 1)
        # ----------------------------------------------------------
        self._image_list = []
        self._flow_list = []
        for i in list_of_indices:
            self._image_list += [[image_filenames[2 * i], image_filenames[2 * i + 1]]]
            self._flow_list += [flow_filenames[i]]

        self._size = len(self._image_list)
        assert len(self._image_list) == len(self._flow_list)

        # ----------------------------------------------------------
        # photometric_augmentations
        # ----------------------------------------------------------
        if photometric_augmentations:
            self._photometric_transform = transforms.ConcatTransformSplitChainer([
                # uint8 -> PIL
                vision_transforms.ToPILImage(),
                # PIL -> PIL : random hsv and contrast
                vision_transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
                # PIL -> FloatTensor
                vision_transforms.transforms.ToTensor(),
                transforms.RandomGamma(min_gamma=0.7, max_gamma=1.5, clip_image=True),
            ], from_numpy=True, to_numpy=False)
        else:
            self._photometric_transform = transforms.ConcatTransformSplitChainer([
                # uint8 -> FloatTensor
                vision_transforms.transforms.ToTensor(),
            ], from_numpy=True, to_numpy=False)

    def __getitem__(self, index):
        """Return one example dict: images, forward flow, dummy occlusion."""
        index = index % self._size

        im1_filename = self._image_list[index][0]
        im2_filename = self._image_list[index][1]
        flo_filename = self._flow_list[index]

        # read uint8 images and float32 flow
        im1_np0 = common.read_image_as_byte(im1_filename)
        im2_np0 = common.read_image_as_byte(im2_filename)
        flo_np0 = common.read_flo_as_float32(flo_filename)

        # possibly apply photometric transformations
        im1, im2 = self._photometric_transform(im1_np0, im2_np0)

        # convert flow to FloatTensor
        flo = common.numpy2torch(flo_np0)

        # target_occ: FlyingChairs has no occlusion ground truth; re-reading
        # im1 and multiplying by zero yields an all-zero map of the right size
        target_occ = common.numpy2torch(common.read_occ_image_as_float32(im1_filename)) * 0

        # example filename (first 5 chars identify the example)
        basename = os.path.basename(im1_filename)[:5]

        example_dict = {
            "input1": im1,
            "input2": im2,
            "target1": flo,
            "target_occ1": target_occ,
            "index": index,
            "basename": basename
        }

        return example_dict

    def __len__(self):
        # Dataset length as required by torch.utils.data.Dataset.
        return self._size
class FlyingChairsTrain(FlyingChairs):
    """FlyingChairs training split (photometric augmentations on by default)."""

    def __init__(self, args, root, photometric_augmentations=True):
        super(FlyingChairsTrain, self).__init__(
            args,
            root=root,
            photometric_augmentations=photometric_augmentations,
            dstype="train")
class FlyingChairsValid(FlyingChairs):
    """FlyingChairs validation split (no augmentations by default)."""

    def __init__(self, args, root, photometric_augmentations=False):
        super(FlyingChairsValid, self).__init__(
            args,
            root=root,
            photometric_augmentations=photometric_augmentations,
            dstype="valid")
class FlyingChairsFull(FlyingChairs):
    """All FlyingChairs examples, train and validation combined."""

    def __init__(self, args, root, photometric_augmentations=False):
        super(FlyingChairsFull, self).__init__(
            args,
            root=root,
            photometric_augmentations=photometric_augmentations,
            dstype="full")
================================================
FILE: datasets/flyingchairsOcc.py
================================================
from __future__ import absolute_import, division, print_function
import os
import torch.utils.data as data
from glob import glob
from torchvision import transforms as vision_transforms
from . import transforms
from . import common
# Example indices held out of FlyingChairsOcc for validation; the remaining
# indices form the training split (see the dstype dispatch in
# FlyingChairsOcc.__init__ below).
VALIDATE_INDICES = [
    5, 17, 42, 45, 58, 62, 96, 111, 117, 120, 121, 131, 132,
    152, 160, 248, 263, 264, 291, 293, 295, 299, 316, 320, 336,
    337, 343, 358, 399, 401, 429, 438, 468, 476, 494, 509, 528,
    531, 572, 581, 583, 588, 593, 681, 688, 696, 714, 767, 786,
    810, 825, 836, 841, 883, 917, 937, 942, 970, 974, 980, 1016,
    1043, 1064, 1118, 1121, 1133, 1153, 1155, 1158, 1159, 1173,
    1187, 1219, 1237, 1238, 1259, 1266, 1278, 1296, 1354, 1378,
    1387, 1494, 1508, 1518, 1574, 1601, 1614, 1668, 1673, 1699,
    1712, 1714, 1737, 1841, 1872, 1879, 1901, 1921, 1934, 1961,
    1967, 1978, 2018, 2030, 2039, 2043, 2061, 2113, 2204, 2216,
    2236, 2250, 2274, 2292, 2310, 2342, 2359, 2374, 2382, 2399,
    2415, 2419, 2483, 2502, 2504, 2576, 2589, 2590, 2622, 2624,
    2636, 2651, 2655, 2658, 2659, 2664, 2672, 2706, 2707, 2709,
    2725, 2732, 2761, 2827, 2864, 2866, 2905, 2922, 2929, 2966,
    2972, 2993, 3010, 3025, 3031, 3040, 3041, 3070, 3113, 3124,
    3129, 3137, 3141, 3157, 3183, 3206, 3219, 3247, 3253, 3272,
    3276, 3321, 3328, 3333, 3338, 3341, 3346, 3351, 3396, 3419,
    3430, 3433, 3448, 3455, 3463, 3503, 3526, 3529, 3537, 3555,
    3577, 3584, 3591, 3594, 3597, 3603, 3613, 3615, 3670, 3676,
    3678, 3697, 3723, 3728, 3734, 3745, 3750, 3752, 3779, 3782,
    3813, 3817, 3819, 3854, 3885, 3944, 3947, 3970, 3985, 4011,
    4022, 4071, 4075, 4132, 4158, 4167, 4190, 4194, 4207, 4246,
    4249, 4298, 4307, 4317, 4318, 4319, 4320, 4382, 4399, 4401,
    4407, 4416, 4423, 4484, 4491, 4493, 4517, 4525, 4538, 4578,
    4606, 4609, 4620, 4623, 4637, 4646, 4662, 4668, 4716, 4739,
    4747, 4770, 4774, 4776, 4785, 4800, 4845, 4863, 4891, 4904,
    4922, 4925, 4956, 4963, 4964, 4994, 5011, 5019, 5036, 5038,
    5041, 5055, 5118, 5122, 5130, 5162, 5164, 5178, 5196, 5227,
    5266, 5270, 5273, 5279, 5299, 5310, 5314, 5363, 5375, 5384,
    5393, 5414, 5417, 5433, 5448, 5494, 5505, 5509, 5525, 5566,
    5581, 5602, 5609, 5620, 5653, 5670, 5678, 5690, 5700, 5703,
    5724, 5752, 5765, 5803, 5811, 5860, 5881, 5895, 5912, 5915,
    5940, 5952, 5966, 5977, 5988, 6007, 6037, 6061, 6069, 6080,
    6111, 6127, 6146, 6161, 6166, 6168, 6178, 6182, 6190, 6220,
    6235, 6253, 6270, 6343, 6372, 6379, 6410, 6411, 6442, 6453,
    6481, 6498, 6500, 6509, 6532, 6541, 6543, 6560, 6576, 6580,
    6594, 6595, 6609, 6625, 6629, 6644, 6658, 6673, 6680, 6698,
    6699, 6702, 6705, 6741, 6759, 6785, 6792, 6794, 6809, 6810,
    6830, 6838, 6869, 6871, 6889, 6925, 6995, 7003, 7026, 7029,
    7080, 7082, 7097, 7102, 7116, 7165, 7200, 7232, 7271, 7282,
    7324, 7333, 7335, 7372, 7387, 7407, 7472, 7474, 7482, 7489,
    7499, 7516, 7533, 7536, 7566, 7620, 7654, 7691, 7704, 7722,
    7746, 7750, 7773, 7806, 7821, 7827, 7851, 7873, 7880, 7884,
    7904, 7912, 7948, 7964, 7965, 7984, 7989, 7992, 8035, 8050,
    8074, 8091, 8094, 8113, 8116, 8151, 8159, 8171, 8179, 8194,
    8195, 8239, 8263, 8290, 8295, 8312, 8367, 8374, 8387, 8407,
    8437, 8439, 8518, 8556, 8588, 8597, 8601, 8651, 8657, 8723,
    8759, 8763, 8785, 8802, 8813, 8826, 8854, 8856, 8866, 8918,
    8922, 8923, 8932, 8958, 8967, 9003, 9018, 9078, 9095, 9104,
    9112, 9129, 9147, 9170, 9171, 9197, 9200, 9249, 9253, 9270,
    9282, 9288, 9295, 9321, 9323, 9324, 9347, 9399, 9403, 9417,
    9426, 9427, 9439, 9468, 9486, 9496, 9511, 9516, 9518, 9529,
    9557, 9563, 9564, 9584, 9586, 9591, 9599, 9600, 9601, 9632,
    9654, 9667, 9678, 9696, 9716, 9723, 9740, 9820, 9824, 9825,
    9828, 9863, 9866, 9868, 9889, 9929, 9938, 9953, 9967, 10019,
    10020, 10025, 10059, 10111, 10118, 10125, 10174, 10194,
    10201, 10202, 10220, 10221, 10226, 10242, 10250, 10276,
    10295, 10302, 10305, 10327, 10351, 10360, 10369, 10393,
    10407, 10438, 10455, 10463, 10465, 10470, 10478, 10503,
    10508, 10509, 10809, 11080, 11331, 11607, 11610, 11864,
    12390, 12393, 12396, 12399, 12671, 12921, 12930, 13178,
    13453, 13717, 14499, 14517, 14775, 15297, 15556, 15834,
    15839, 16126, 16127, 16386, 16633, 16644, 16651, 17166,
    17169, 17958, 17959, 17962, 18224, 21176, 21180, 21190,
    21802, 21803, 21806, 22584, 22857, 22858, 22866]
class FlyingChairsOcc(data.Dataset):
    """FlyingChairsOcc dataset: image pairs, fwd/bwd flow, and occlusion maps.

    Args:
        args: global configuration namespace (stored, not interpreted here).
        root: directory holding ``*_img{1,2}.png``, ``*_occ{1,2}.png``,
            ``*_flow.flo`` and ``*_flow_b.flo`` files.
        photometric_augmentations: if True, apply random color jitter and gamma.
        dstype: split selection — "train", "valid", or "full".

    Raises:
        ValueError: if ``dstype`` is not one of the three known splits.
    """

    def __init__(self,
                 args,
                 root,
                 photometric_augmentations=False,
                 dstype="train"):
        self._args = args

        # -------------------------------------------------------------
        # filenames for all input images, occlusion maps and target flows
        # -------------------------------------------------------------
        image1_filenames = sorted(glob(os.path.join(root, "*_img1.png")))
        image2_filenames = sorted(glob(os.path.join(root, "*_img2.png")))
        occ1_filenames = sorted(glob(os.path.join(root, "*_occ1.png")))
        occ2_filenames = sorted(glob(os.path.join(root, "*_occ2.png")))
        flow_f_filenames = sorted(glob(os.path.join(root, "*_flow.flo")))
        flow_b_filenames = sorted(glob(os.path.join(root, "*_flow_b.flo")))
        # every example must provide all six files
        assert (len(image1_filenames) == len(image2_filenames))
        assert (len(image2_filenames) == len(occ1_filenames))
        assert (len(occ1_filenames) == len(occ2_filenames))
        assert (len(occ2_filenames) == len(flow_f_filenames))
        assert (len(flow_f_filenames) == len(flow_b_filenames))
        num_flows = len(flow_f_filenames)

        # -------------------------------------------------------------
        # Keep only validation indices that exist for this root
        # -------------------------------------------------------------
        validate_indices = [x for x in VALIDATE_INDICES if x in range(num_flows)]

        # ----------------------------------------------------------
        # Construct list of indices for training/validation
        # ----------------------------------------------------------
        if dstype == "train":
            # set() turns the O(n^2) exclusion scan into O(n)
            validate_set = set(validate_indices)
            list_of_indices = [x for x in range(num_flows) if x not in validate_set]
        elif dstype == "valid":
            list_of_indices = validate_indices
        elif dstype == "full":
            list_of_indices = range(num_flows)
        else:
            # BUGFIX: message is now %-formatted (it was passed as a second
            # exception argument) and names the correct class
            raise ValueError("FlyingChairsOcc: dstype '%s' unknown!" % dstype)

        # ----------------------------------------------------------
        # Save list of actual filenames for inputs and flows
        # ----------------------------------------------------------
        self._image_list = []
        self._flow_list = []
        self._occ_list = []
        for i in list_of_indices:
            self._image_list += [[image1_filenames[i], image2_filenames[i]]]
            self._flow_list += [[flow_f_filenames[i], flow_b_filenames[i]]]
            self._occ_list += [[occ1_filenames[i], occ2_filenames[i]]]

        self._size = len(self._image_list)
        assert len(self._image_list) == len(self._flow_list)
        assert len(self._occ_list) == len(self._flow_list)

        # ----------------------------------------------------------
        # photometric_augmentations
        # ----------------------------------------------------------
        if photometric_augmentations:
            self._photometric_transform = transforms.ConcatTransformSplitChainer([
                # uint8 -> PIL
                vision_transforms.ToPILImage(),
                # PIL -> PIL : random hsv and contrast
                vision_transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
                # PIL -> FloatTensor
                vision_transforms.transforms.ToTensor(),
                transforms.RandomGamma(min_gamma=0.7, max_gamma=1.5, clip_image=True),
            ], from_numpy=True, to_numpy=False)
        else:
            self._photometric_transform = transforms.ConcatTransformSplitChainer([
                # uint8 -> FloatTensor
                vision_transforms.transforms.ToTensor(),
            ], from_numpy=True, to_numpy=False)

    def __getitem__(self, index):
        """Return one example dict: images, fwd/bwd flow, both occlusion maps."""
        index = index % self._size

        im1_filename = self._image_list[index][0]
        im2_filename = self._image_list[index][1]
        flo_f_filename = self._flow_list[index][0]
        flo_b_filename = self._flow_list[index][1]
        occ1_filename = self._occ_list[index][0]
        occ2_filename = self._occ_list[index][1]

        # read uint8 images, float32 flow, and float32 occlusion maps
        im1_np0 = common.read_image_as_byte(im1_filename)
        im2_np0 = common.read_image_as_byte(im2_filename)
        flo_f_np0 = common.read_flo_as_float32(flo_f_filename)
        flo_b_np0 = common.read_flo_as_float32(flo_b_filename)
        occ1_np0 = common.read_occ_image_as_float32(occ1_filename)
        occ2_np0 = common.read_occ_image_as_float32(occ2_filename)

        # possibly apply photometric transformations
        im1, im2 = self._photometric_transform(im1_np0, im2_np0)

        # convert flow and occlusion maps to FloatTensor
        flo_f = common.numpy2torch(flo_f_np0)
        flo_b = common.numpy2torch(flo_b_np0)
        occ1 = common.numpy2torch(occ1_np0)
        occ2 = common.numpy2torch(occ2_np0)

        # example filename (first 5 chars identify the example)
        basename = os.path.basename(im1_filename)[:5]

        example_dict = {
            "input1": im1,
            "input2": im2,
            "target1": flo_f,
            "target2": flo_b,
            "target_occ1": occ1,
            "target_occ2": occ2,
            "index": index,
            "basename": basename
        }

        return example_dict

    def __len__(self):
        # Dataset length as required by torch.utils.data.Dataset.
        return self._size
class FlyingChairsOccTrain(FlyingChairsOcc):
    """FlyingChairsOcc training split (photometric augmentations on by default)."""

    def __init__(self, args, root, photometric_augmentations=True):
        super(FlyingChairsOccTrain, self).__init__(
            args,
            root=root,
            photometric_augmentations=photometric_augmentations,
            dstype="train")
class FlyingChairsOccValid(FlyingChairsOcc):
    """FlyingChairsOcc validation split (no augmentations by default)."""

    def __init__(self, args, root, photometric_augmentations=False):
        super(FlyingChairsOccValid, self).__init__(
            args,
            root=root,
            photometric_augmentations=photometric_augmentations,
            dstype="valid")
class FlyingChairsOccFull(FlyingChairsOcc):
    """All FlyingChairsOcc examples, train and validation combined."""

    def __init__(self, args, root, photometric_augmentations=False):
        super(FlyingChairsOccFull, self).__init__(
            args,
            root=root,
            photometric_augmentations=photometric_augmentations,
            dstype="full")
================================================
FILE: datasets/kitti_combined.py
================================================
from __future__ import absolute_import, division, print_function
import os
import torch.utils.data as data
from glob import glob
from torchvision import transforms as vision_transforms
from . import transforms
from . import common
import numpy as np
import png
# Ground-truth example indices held out for validation on KITTI 2015 / 2012;
# the remaining indices form the respective training split (see Kitti_comb).
VALIDATE_INDICES_2015 = [10, 11, 12, 25, 26, 30, 31, 40, 41, 42, 46, 52, 53, 72, 73, 74, 75, 76, 80, 81, 85, 86, 95, 96, 97, 98, 104, 116, 117, 120, 121, 126, 127, 153, 172, 175, 183, 184, 190, 199]
VALIDATE_INDICES_2012 = [0, 12, 15, 16, 17, 18, 24, 30, 38, 39, 42, 50, 54, 59, 60, 61, 77, 78, 81, 89, 97, 101, 107, 121, 124, 142, 145, 146, 152, 154, 155, 158, 159, 160, 164, 182, 183, 184, 190]
def read_png_flow(flow_file):
    """Decode a KITTI 16-bit PNG flow file.

    Returns a tuple ``(flow, valid)``: ``flow`` is an (H, W, 2) float64 array
    in pixels, ``valid`` is an (H, W, 1) int mask (1 where GT is present).
    """
    reader = png.Reader(filename=flow_file)
    direct = reader.asDirect()
    row_data = list(direct[2])
    (w, h) = direct[3]['size']

    flow = np.zeros((h, w, 3), dtype=np.float64)
    for row_idx, row in enumerate(row_data):
        # each PNG row interleaves three channels: u, v, valid-flag
        flow[row_idx, :, 0] = row[0::3]
        flow[row_idx, :, 1] = row[1::3]
        flow[row_idx, :, 2] = row[2::3]

    invalid_idx = (flow[:, :, 2] == 0)
    # undo the stored encoding: value = flow * 64 + 2^15
    flow[:, :, 0:2] = (flow[:, :, 0:2] - 2 ** 15) / 64.0
    # zero out flow wherever no ground truth exists
    flow[invalid_idx, 0] = 0
    flow[invalid_idx, 1] = 0

    return flow[:, :, 0:2], (1 - invalid_idx * 1)[:, :, None]
def kitti_random_crop(im1, im2, flo_f, valid_mask, crop_height=370, crop_width=1224):
    """Randomly crop an aligned (im1, im2, flow, valid-mask) tuple.

    The same crop window is applied to all four (H, W, C) arrays so the
    flow/mask stay registered with the images.
    """
    height, width, _ = im1.shape

    # draw the top-left corner uniformly over all valid placements
    str_x = int(np.random.uniform(0, width - crop_width + 1))
    str_y = int(np.random.uniform(0, height - crop_height + 1))
    end_x = str_x + crop_width
    end_y = str_y + crop_height

    cropped = tuple(arr[str_y:end_y, str_x:end_x, :]
                    for arr in (im1, im2, flo_f, valid_mask))
    return cropped
class Kitti_comb_test(data.Dataset):
    """Test-time loader over KITTI 2015 and/or 2012 image pairs (no ground truth)."""

    def __init__(self,
                 args,
                 images_root_2015=None,
                 images_root_2012=None,
                 photometric_augmentations=False,
                 preprocessing_crop=True):
        self._args = args
        self.preprocessing_crop = preprocessing_crop

        self._image_list = []
        self._flow_list = []

        # ----------------------------------------------------------
        # KITTI 2015 (optional)
        # ----------------------------------------------------------
        if images_root_2015 is not None:
            if not os.path.isdir(images_root_2015):
                raise ValueError("Image directory not found! {}".format(images_root_2015))
            first_2015 = sorted(glob(os.path.join(images_root_2015, "*_10.png")))
            second_2015 = sorted(glob(os.path.join(images_root_2015, "*_11.png")))
            assert len(first_2015) != 0
            assert len(second_2015) == len(first_2015)
            self._collect_pairs(first_2015, second_2015)

        # ----------------------------------------------------------
        # KITTI 2012 (optional)
        # ----------------------------------------------------------
        if images_root_2012 is not None:
            if not os.path.isdir(images_root_2012):
                raise ValueError("Image directory not found! {}".format(images_root_2012))
            first_2012 = sorted(glob(os.path.join(images_root_2012, "*_10.png")))
            second_2012 = sorted(glob(os.path.join(images_root_2012, "*_11.png")))
            assert len(first_2012) != 0
            assert len(second_2012) == len(first_2012)
            self._collect_pairs(first_2012, second_2012)

        self._size = len(self._image_list)
        assert len(self._image_list) != 0

        # ----------------------------------------------------------
        # photometric_augmentations
        # ----------------------------------------------------------
        if photometric_augmentations:
            self._photometric_transform = transforms.ConcatTransformSplitChainer([
                # uint8 -> PIL
                vision_transforms.ToPILImage(),
                # PIL -> PIL : random hsv and contrast
                vision_transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
                # PIL -> FloatTensor
                vision_transforms.transforms.ToTensor(),
                transforms.RandomGamma(min_gamma=0.7, max_gamma=1.5, clip_image=True),
            ], from_numpy=True, to_numpy=False)
        else:
            self._photometric_transform = transforms.ConcatTransformSplitChainer([
                # uint8 -> FloatTensor
                vision_transforms.transforms.ToTensor(),
            ], from_numpy=True, to_numpy=False)

    def _collect_pairs(self, first_frames, second_frames):
        # Append matching (frame_10, frame_11) pairs to the example list;
        # basenames must agree once the _10/_11 suffix is stripped.
        for im1, im2 in zip(first_frames, second_frames):
            idx1 = os.path.splitext(os.path.basename(im1))[0][:-3]
            idx2 = os.path.splitext(os.path.basename(im2))[0][:-3]
            assert idx1 == idx2
            if os.path.isfile(im1) and os.path.isfile(im2):
                self._image_list.append([im1, im2])

    def __getitem__(self, index):
        """Return one example dict with the two (possibly augmented) images."""
        index = index % self._size

        im1_filename = self._image_list[index][0]
        im2_filename = self._image_list[index][1]

        # read uint8 images
        im1_np0 = common.read_image_as_byte(im1_filename)
        im2_np0 = common.read_image_as_byte(im2_filename)

        # possibly apply photometric transformations
        im1, im2 = self._photometric_transform(im1_np0, im2_np0)

        return {
            "input1": im1,
            "input2": im2,
            "index": index,
            "basename": os.path.basename(im1_filename)[:6]
        }

    def __len__(self):
        # Dataset length as required by torch.utils.data.Dataset.
        return self._size
class Kitti_comb(data.Dataset):
    """Combined KITTI 2015 + 2012 optical-flow dataset with ground truth.

    Builds one example list from either or both KITTI releases; forward flow
    is read from 16-bit PNGs together with its per-pixel validity mask.

    Args:
        args: global configuration namespace (stored, not interpreted here).
        images_root_2015 / flow_root_2015: KITTI 2015 image and flow dirs.
        images_root_2012 / flow_root_2012: KITTI 2012 image and flow dirs.
        photometric_augmentations: if True, apply random color jitter / gamma.
        preprocessing_crop: if True, random-crop every example in __getitem__.
        dstype: "train", "valid", "full", or "test" (no flow kept for "test").

    Raises:
        ValueError: for missing directories or an unknown ``dstype``.
    """

    def __init__(self,
                 args,
                 images_root_2015=None,
                 flow_root_2015=None,
                 images_root_2012=None,
                 flow_root_2012=None,
                 photometric_augmentations=False,
                 preprocessing_crop=True,
                 dstype="full"):
        self._args = args
        self.preprocessing_crop = preprocessing_crop

        list_of_indices_2012 = []
        list_of_indices_2015 = []

        # ----------------------------------------------------------
        # KITTI 2015
        # ----------------------------------------------------------
        if images_root_2015 is not None and flow_root_2015 is not None:
            if not os.path.isdir(images_root_2015):
                raise ValueError("Image directory not found! {}".format(images_root_2015))
            if not os.path.isdir(flow_root_2015):
                raise ValueError("Flow directory not found! {}".format(flow_root_2015))
            all_img1_2015_filenames = sorted(glob(os.path.join(images_root_2015, "*_10.png")))
            all_img2_2015_filenames = sorted(glob(os.path.join(images_root_2015, "*_11.png")))
            flow_f_2015_filenames = sorted(glob(os.path.join(flow_root_2015, "*_10.png")))
            assert len(all_img1_2015_filenames) != 0
            assert len(all_img2_2015_filenames) == len(all_img1_2015_filenames)
            assert len(flow_f_2015_filenames) == len(all_img1_2015_filenames)
            num_flows_2015 = len(flow_f_2015_filenames)
            validate_indices_2015 = [x for x in VALIDATE_INDICES_2015 if x in range(num_flows_2015)]
            if dstype == "train":
                list_of_indices_2015 = [x for x in range(num_flows_2015) if x not in validate_indices_2015]
            elif dstype == "valid":
                list_of_indices_2015 = validate_indices_2015
            elif dstype == "full":
                list_of_indices_2015 = range(len(all_img1_2015_filenames))
            else:
                # BUGFIX: message was never %-formatted (value was passed as a
                # second exception argument instead of being interpolated)
                raise ValueError("KITTI 2015: dstype '%s' unknown!" % dstype)

        # ----------------------------------------------------------
        # KITTI 2012
        # ----------------------------------------------------------
        if images_root_2012 is not None:
            if not os.path.isdir(images_root_2012):
                # BUGFIX: the '%s' placeholder was previously left unfilled
                raise ValueError("Image directory '%s' not found!" % images_root_2012)
            if not os.path.isdir(flow_root_2012):
                # BUGFIX: the '%s' placeholder was previously left unfilled
                raise ValueError("Flow directory '%s' not found!" % flow_root_2012)
            all_img1_2012_filenames = sorted(glob(os.path.join(images_root_2012, "*_10.png")))
            all_img2_2012_filenames = sorted(glob(os.path.join(images_root_2012, "*_11.png")))
            flow_f_2012_filenames = sorted(glob(os.path.join(flow_root_2012, "*_10.png")))
            assert len(all_img1_2012_filenames) != 0
            assert len(all_img2_2012_filenames) == len(all_img1_2012_filenames)
            assert len(flow_f_2012_filenames) == len(all_img1_2012_filenames)
            num_flows_2012 = len(flow_f_2012_filenames)
            validate_indices_2012 = [x for x in VALIDATE_INDICES_2012 if x in range(num_flows_2012)]
            if dstype == "train":
                list_of_indices_2012 = [x for x in range(num_flows_2012) if x not in validate_indices_2012]
            elif dstype == "valid":
                list_of_indices_2012 = validate_indices_2012
            elif dstype == "full":
                list_of_indices_2012 = range(len(all_img1_2012_filenames))
            else:
                # BUGFIX: same unformatted-message defect as above
                raise ValueError("KITTI 2012: dstype '%s' unknown!" % dstype)

        # ----------------------------------------------------------
        # Save list of actual filenames for inputs and flows
        # ----------------------------------------------------------
        self._image_list = []
        self._flow_list = []
        for ii in list_of_indices_2015:
            im1 = all_img1_2015_filenames[ii]
            im2 = all_img2_2015_filenames[ii]
            # basenames (sans the _10/_11 suffix) must identify the same frame
            idx1 = os.path.splitext(os.path.basename(im1))[0][:-3]
            idx2 = os.path.splitext(os.path.basename(im2))[0][:-3]
            assert idx1 == idx2
            if not os.path.isfile(im1) or not os.path.isfile(im2):
                continue
            self._image_list += [[im1, im2]]
            # BUGFIX: was `dstype is not "test"` — identity comparison with a
            # string literal is interning-dependent; use equality instead
            if dstype != "test":
                flo_f = flow_f_2015_filenames[ii]
                idx_f = os.path.splitext(os.path.basename(flo_f))[0][:-3]
                assert idx1 == idx_f
                if not os.path.isfile(flo_f):
                    continue
                self._flow_list += [[flo_f]]

        for ii in list_of_indices_2012:
            im1 = all_img1_2012_filenames[ii]
            im2 = all_img2_2012_filenames[ii]
            idx1 = os.path.splitext(os.path.basename(im1))[0][:-3]
            idx2 = os.path.splitext(os.path.basename(im2))[0][:-3]
            assert idx1 == idx2
            if not os.path.isfile(im1) or not os.path.isfile(im2):
                continue
            self._image_list += [[im1, im2]]
            # BUGFIX: `is not` -> `!=` (see above)
            if dstype != "test":
                flo_f = flow_f_2012_filenames[ii]
                idx_f = os.path.splitext(os.path.basename(flo_f))[0][:-3]
                assert idx1 == idx_f
                if not os.path.isfile(flo_f):
                    continue
                self._flow_list += [[flo_f]]

        self._size = len(self._image_list)
        assert len(self._image_list) != 0
        # BUGFIX: `is not` -> `!=` (see above)
        if dstype != "test":
            assert len(self._image_list) == len(self._flow_list)

        # ----------------------------------------------------------
        # photometric_augmentations
        # ----------------------------------------------------------
        if photometric_augmentations:
            self._photometric_transform = transforms.ConcatTransformSplitChainer([
                # uint8 -> PIL
                vision_transforms.ToPILImage(),
                # PIL -> PIL : random hsv and contrast
                vision_transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
                # PIL -> FloatTensor
                vision_transforms.transforms.ToTensor(),
                transforms.RandomGamma(min_gamma=0.7, max_gamma=1.5, clip_image=True),
            ], from_numpy=True, to_numpy=False)
        else:
            self._photometric_transform = transforms.ConcatTransformSplitChainer([
                # uint8 -> FloatTensor
                vision_transforms.transforms.ToTensor(),
            ], from_numpy=True, to_numpy=False)

    def __getitem__(self, index):
        """Return one example: images, forward flow, and its validity mask."""
        index = index % self._size

        im1_filename = self._image_list[index][0]
        im2_filename = self._image_list[index][1]
        flo_f_filename = self._flow_list[index][0]

        # read uint8 images and the 16-bit PNG flow (+ per-pixel validity)
        im1_np0 = common.read_image_as_byte(im1_filename)
        im2_np0 = common.read_image_as_byte(im2_filename)
        flo_f_np0, valid_mask = read_png_flow(flo_f_filename)

        # optional random crop so batched examples share one resolution
        if self.preprocessing_crop:
            im1_np0, im2_np0, flo_f_np0, valid_mask = kitti_random_crop(im1_np0, im2_np0, flo_f_np0, valid_mask)

        # possibly apply photometric transformations
        im1, im2 = self._photometric_transform(im1_np0, im2_np0)

        # convert flow and mask to FloatTensor
        flo_f = common.numpy2torch(flo_f_np0)
        valid_mask_f = common.numpy2torch(valid_mask)

        # example filename (first 6 chars identify the KITTI frame)
        basename = os.path.basename(im1_filename)[:6]

        example_dict = {
            "input1": im1,
            "input2": im2,
            "target1": flo_f,
            # NOTE(review): KITTI provides no backward ground truth, so the
            # forward flow is duplicated here — confirm downstream losses
            # either ignore or expect this
            "target2": flo_f,
            "index": index,
            "basename": basename,
            "input_valid": valid_mask_f
        }

        return example_dict

    def __len__(self):
        # Dataset length as required by torch.utils.data.Dataset.
        return self._size
class KittiCombTrain(Kitti_comb):
    """Training split over the combined KITTI 2015 + 2012 data."""

    def __init__(self, args, root, photometric_augmentations=True, preprocessing_crop=True):
        kitti_2015 = os.path.join(root, "data_scene_flow", "training")
        kitti_2012 = os.path.join(root, "data_stereo_flow", "training")
        super(KittiCombTrain, self).__init__(
            args,
            images_root_2015=os.path.join(kitti_2015, "image_2"),
            flow_root_2015=os.path.join(kitti_2015, "flow_occ"),
            images_root_2012=os.path.join(kitti_2012, "colored_0"),
            flow_root_2012=os.path.join(kitti_2012, "flow_occ"),
            photometric_augmentations=photometric_augmentations,
            preprocessing_crop=preprocessing_crop,
            dstype="train")
class KittiCombVal(Kitti_comb):
    """Validation split over the combined KITTI 2015 + 2012 data."""

    def __init__(self, args, root, photometric_augmentations=False, preprocessing_crop=False):
        kitti_2015 = os.path.join(root, "data_scene_flow", "training")
        kitti_2012 = os.path.join(root, "data_stereo_flow", "training")
        super(KittiCombVal, self).__init__(
            args,
            images_root_2015=os.path.join(kitti_2015, "image_2"),
            flow_root_2015=os.path.join(kitti_2015, "flow_occ"),
            images_root_2012=os.path.join(kitti_2012, "colored_0"),
            flow_root_2012=os.path.join(kitti_2012, "flow_occ"),
            photometric_augmentations=photometric_augmentations,
            preprocessing_crop=preprocessing_crop,
            dstype="valid")
class KittiCombFull(Kitti_comb):
    """Full (train + validation) combined KITTI 2012 + 2015 flow data."""
    def __init__(self, args, root, photometric_augmentations=True, preprocessing_crop=True):
        super(KittiCombFull, self).__init__(
            args,
            images_root_2015=os.path.join(root, "data_scene_flow", "training", "image_2"),
            flow_root_2015=os.path.join(root, "data_scene_flow", "training", "flow_occ"),
            images_root_2012=os.path.join(root, "data_stereo_flow", "training", "colored_0"),
            flow_root_2012=os.path.join(root, "data_stereo_flow", "training", "flow_occ"),
            photometric_augmentations=photometric_augmentations,
            preprocessing_crop=preprocessing_crop,
            dstype="full")
class KittiComb2015Train(Kitti_comb):
    """Training split using only KITTI 2015 flow data."""
    def __init__(self, args, root, photometric_augmentations=True, preprocessing_crop=True):
        super(KittiComb2015Train, self).__init__(
            args,
            images_root_2015=os.path.join(root, "data_scene_flow", "training", "image_2"),
            flow_root_2015=os.path.join(root, "data_scene_flow", "training", "flow_occ"),
            photometric_augmentations=photometric_augmentations,
            preprocessing_crop=preprocessing_crop,
            dstype="train")
class KittiComb2015Val(Kitti_comb):
    """Validation split using only KITTI 2015 flow data."""
    def __init__(self, args, root, photometric_augmentations=False, preprocessing_crop=False):
        super(KittiComb2015Val, self).__init__(
            args,
            images_root_2015=os.path.join(root, "data_scene_flow", "training", "image_2"),
            flow_root_2015=os.path.join(root, "data_scene_flow", "training", "flow_occ"),
            photometric_augmentations=photometric_augmentations,
            preprocessing_crop=preprocessing_crop,
            dstype="valid")
class KittiComb2015Full(Kitti_comb):
    """Full (train + validation) KITTI 2015 flow data."""
    def __init__(self, args, root, photometric_augmentations=True, preprocessing_crop=True):
        super(KittiComb2015Full, self).__init__(
            args,
            images_root_2015=os.path.join(root, "data_scene_flow", "training", "image_2"),
            flow_root_2015=os.path.join(root, "data_scene_flow", "training", "flow_occ"),
            photometric_augmentations=photometric_augmentations,
            preprocessing_crop=preprocessing_crop,
            dstype="full")
class KittiComb2015Test(Kitti_comb_test):
    """KITTI 2015 test images (no flow ground truth available)."""
    def __init__(self, args, root, photometric_augmentations=False, preprocessing_crop=False):
        super(KittiComb2015Test, self).__init__(
            args,
            images_root_2015=os.path.join(root, "data_scene_flow", "testing", "image_2"),
            photometric_augmentations=photometric_augmentations,
            preprocessing_crop=preprocessing_crop)
class KittiComb2012Train(Kitti_comb):
    """Training split using only KITTI 2012 flow data."""
    def __init__(self, args, root, photometric_augmentations=True, preprocessing_crop=True):
        super(KittiComb2012Train, self).__init__(
            args,
            images_root_2012=os.path.join(root, "data_stereo_flow", "training", "colored_0"),
            flow_root_2012=os.path.join(root, "data_stereo_flow", "training", "flow_occ"),
            photometric_augmentations=photometric_augmentations,
            preprocessing_crop=preprocessing_crop,
            dstype="train")
class KittiComb2012Val(Kitti_comb):
    """Validation split using only KITTI 2012 flow data."""
    def __init__(self, args, root, photometric_augmentations=False, preprocessing_crop=False):
        super(KittiComb2012Val, self).__init__(
            args,
            images_root_2012=os.path.join(root, "data_stereo_flow", "training", "colored_0"),
            flow_root_2012=os.path.join(root, "data_stereo_flow", "training", "flow_occ"),
            photometric_augmentations=photometric_augmentations,
            preprocessing_crop=preprocessing_crop,
            dstype="valid")
class KittiComb2012Full(Kitti_comb):
    """Full (train + validation) KITTI 2012 flow data."""
    def __init__(self, args, root, photometric_augmentations=True, preprocessing_crop=True):
        super(KittiComb2012Full, self).__init__(
            args,
            images_root_2012=os.path.join(root, "data_stereo_flow", "training", "colored_0"),
            flow_root_2012=os.path.join(root, "data_stereo_flow", "training", "flow_occ"),
            photometric_augmentations=photometric_augmentations,
            preprocessing_crop=preprocessing_crop,
            dstype="full")
class KittiComb2012Test(Kitti_comb_test):
    """KITTI 2012 test images (no flow ground truth available)."""
    def __init__(self, args, root, photometric_augmentations=False, preprocessing_crop=False):
        super(KittiComb2012Test, self).__init__(
            args,
            images_root_2012=os.path.join(root, "data_stereo_flow", "testing", "colored_0"),
            photometric_augmentations=photometric_augmentations,
            preprocessing_crop=preprocessing_crop)
================================================
FILE: datasets/sintel.py
================================================
from __future__ import absolute_import, division, print_function
import os
import torch.utils.data as data
from glob import glob
from torchvision import transforms as vision_transforms
from . import transforms
from . import common
import tools
# Example indices (into the sorted list of consecutive frame pairs) held out
# as the Sintel validation split; the remaining indices form the train split.
VALIDATE_INDICES = [
    199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210,
    211, 212, 213, 214, 215, 216, 217, 340, 341, 342, 343, 344,
    345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356,
    357, 358, 359, 360, 361, 362, 363, 364, 536, 537, 538, 539,
    540, 541, 542, 543, 544, 545, 546, 547, 548, 549, 550, 551,
    552, 553, 554, 555, 556, 557, 558, 559, 560, 659, 660, 661,
    662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673,
    674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685,
    686, 687, 688, 689, 690, 691, 692, 693, 694, 695, 696, 697,
    967, 968, 969, 970, 971, 972, 973, 974, 975, 976, 977, 978,
    979, 980, 981, 982, 983, 984, 985, 986, 987, 988, 989, 990,
    991]
class _Sintel(data.Dataset):
    """MPI-Sintel training data: consecutive image pairs with forward flow
    and reversed occlusion maps.

    Args:
        args: global configuration namespace (kept for reference).
        dir_root: path to the Sintel "training" directory.
        photometric_augmentations: apply random color jitter / gamma if True.
        imgtype: "clean", "final", or "comb" (clean and final combined).
        dstype: split selection, one of "train", "valid", "full".

    Raises:
        ValueError: if an expected directory is missing or dstype is unknown.
    """

    def __init__(self,
                 args,
                 dir_root=None,
                 photometric_augmentations=False,
                 imgtype=None,
                 dstype=None):
        self._args = args
        images_root = os.path.join(dir_root, imgtype)
        # Note: string comparison with '==' (the original 'is' relied on
        # interpreter-dependent literal interning).
        if imgtype == "comb":
            # "comb" indexes the clean pass first; the final-pass twins of
            # every example are appended further below.
            images_root = os.path.join(dir_root, "clean")
        flow_root = os.path.join(dir_root, "flow")
        occ_root = os.path.join(dir_root, "occlusions_rev")
        # Error messages now actually interpolate the offending path.
        if not os.path.isdir(images_root):
            raise ValueError("Image directory '%s' not found!" % images_root)
        if not os.path.isdir(flow_root):
            raise ValueError("Flow directory '%s' not found!" % flow_root)
        if not os.path.isdir(occ_root):
            raise ValueError("Occ directory '%s' not found!" % occ_root)

        all_flo_filenames = sorted(glob(os.path.join(flow_root, "*/*.flo")))
        all_occ_filenames = sorted(glob(os.path.join(occ_root, "*/*.png")))
        all_img_filenames = sorted(glob(os.path.join(images_root, "*/*.png")))

        # Remember base for subtraction at runtime,
        # e.g. "/home/user/.../MPI-Sintel-Complete/training"
        self._substract_base = tools.cd_dotdot(images_root)

        # Unique scene folders, e.g. ["alley_1", "alley_2", "ambush_2", ...]
        substract_full_base = tools.cd_dotdot(all_img_filenames[0])
        base_folders = sorted(list(set([
            os.path.dirname(fn.replace(substract_full_base, ""))[1:] for fn in all_img_filenames
        ])))

        self._image_list = []
        self._flow_list = []
        self._occ_list = []

        for base_folder in base_folders:
            img_filenames = [x for x in all_img_filenames if base_folder in x]
            flo_filenames = [x for x in all_flo_filenames if base_folder in x]
            occ_filenames = [x for x in all_occ_filenames if base_folder in x]
            # Pair consecutive frames; flow/occ entry i maps frame i -> i+1.
            for i in range(len(img_filenames) - 1):
                im1 = img_filenames[i]
                im2 = img_filenames[i + 1]
                flo = flo_filenames[i]
                occ = occ_filenames[i]
                self._image_list += [[im1, im2]]
                self._flow_list += [flo]
                self._occ_list += [occ]
                # Sanity check: frame prefixes and numbers must line up.
                im1_base_filename = os.path.splitext(os.path.basename(im1))[0]
                im2_base_filename = os.path.splitext(os.path.basename(im2))[0]
                flo_base_filename = os.path.splitext(os.path.basename(flo))[0]
                occ_base_filename = os.path.splitext(os.path.basename(occ))[0]
                im1_frame, im1_no = im1_base_filename.split("_")
                im2_frame, im2_no = im2_base_filename.split("_")
                assert im1_frame == im2_frame
                assert int(im1_no) == int(im2_no) - 1
                flo_frame, flo_no = flo_base_filename.split("_")
                assert im1_frame == flo_frame
                assert int(im1_no) == int(flo_no)
                occ_frame, occ_no = occ_base_filename.split("_")
                assert im1_frame == occ_frame
                assert int(im1_no) == int(occ_no)

        assert len(self._image_list) == len(self._flow_list)
        assert len(self._image_list) == len(self._occ_list)

        # Keep only validation indices that exist in this dataset.
        full_num_examples = len(self._image_list)
        validate_indices = [x for x in VALIDATE_INDICES if x in range(full_num_examples)]

        # Construct list of indices for training/validation.
        if dstype == "train":
            list_of_indices = [x for x in range(full_num_examples) if x not in validate_indices]
        elif dstype == "valid":
            list_of_indices = validate_indices
        elif dstype == "full":
            list_of_indices = range(full_num_examples)
        else:
            raise ValueError("dstype '%s' unknown!" % dstype)

        # Save list of actual filenames for inputs and flows.
        self._image_list = [self._image_list[i] for i in list_of_indices]
        self._flow_list = [self._flow_list[i] for i in list_of_indices]
        self._occ_list = [self._occ_list[i] for i in list_of_indices]

        if imgtype == "comb":
            # Append the "final" render of every clean pair; flow and occ
            # ground truth is shared between the two passes.
            image_list_final = [[val[0].replace("clean", "final"), val[1].replace("clean", "final")] for idx, val in enumerate(self._image_list)]
            self._image_list += image_list_final
            self._flow_list += self._flow_list
            self._occ_list += self._occ_list

        assert len(self._image_list) == len(self._flow_list)
        assert len(self._image_list) == len(self._occ_list)

        # ----------------------------------------------------------
        # photometric_augmentations
        # ----------------------------------------------------------
        if photometric_augmentations:
            self._photometric_transform = transforms.ConcatTransformSplitChainer([
                # uint8 -> PIL
                vision_transforms.ToPILImage(),
                # PIL -> PIL : random hsv and contrast
                vision_transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
                # PIL -> FloatTensor
                vision_transforms.transforms.ToTensor(),
                transforms.RandomGamma(min_gamma=0.7, max_gamma=1.5, clip_image=True),
            ], from_numpy=True, to_numpy=False)
        else:
            self._photometric_transform = transforms.ConcatTransformSplitChainer([
                # uint8 -> FloatTensor
                vision_transforms.transforms.ToTensor(),
            ], from_numpy=True, to_numpy=False)

        self._size = len(self._image_list)

    def __getitem__(self, index):
        """Return one example dict: inputs, flow/occ targets, and naming info."""
        index = index % self._size
        im1_filename = self._image_list[index][0]
        im2_filename = self._image_list[index][1]
        flo_filename = self._flow_list[index]
        occ_filename = self._occ_list[index]
        # read uint8 images, float32 flow, and float32 occlusion map
        im1_np0 = common.read_image_as_byte(im1_filename)
        im2_np0 = common.read_image_as_byte(im2_filename)
        flo_np0 = common.read_flo_as_float32(flo_filename)
        occ_np0 = common.read_occ_image_as_float32(occ_filename)
        # possibly apply photometric transformations
        im1, im2 = self._photometric_transform(im1_np0, im2_np0)
        flo = common.numpy2torch(flo_np0)
        occ = common.numpy2torch(occ_np0)
        # e.g. "clean/alley_1"
        basedir = os.path.splitext(os.path.dirname(im1_filename).replace(self._substract_base, "")[1:])[0]
        # e.g. "frame_0001"
        basename = os.path.splitext(os.path.basename(im1_filename))[0]
        example_dict = {
            "input1": im1,
            "input2": im2,
            "index": index,
            "basedir": basedir,
            "basename": basename,
            "target1": flo,
            "target_occ1": occ
        }
        return example_dict

    def __len__(self):
        """Number of examples in the selected split."""
        return self._size
class _Sintel_test(data.Dataset):
    """MPI-Sintel test data: consecutive image pairs only (no ground truth).

    Args:
        args: global configuration namespace (kept for reference).
        dir_root: path to the Sintel "test" directory.
        photometric_augmentations: apply random color jitter / gamma if True.
        imgtype: "clean" or "final".

    Raises:
        ValueError: if the image directory is missing.
    """

    def __init__(self,
                 args,
                 dir_root=None,
                 photometric_augmentations=False,
                 imgtype=None):
        self._args = args
        images_root = os.path.join(dir_root, imgtype)
        # Error message now actually interpolates the offending path.
        if not os.path.isdir(images_root):
            raise ValueError("Image directory '%s' not found!" % images_root)

        all_img_filenames = sorted(glob(os.path.join(images_root, "*/*.png")))

        # Remember base for subtraction at runtime,
        # e.g. "/home/user/.../MPI-Sintel-Complete/test"
        self._substract_base = tools.cd_dotdot(images_root)

        # Unique scene folders, e.g. ["ambush_1", "bamboo_3", ...]
        substract_full_base = tools.cd_dotdot(all_img_filenames[0])
        base_folders = sorted(list(set([
            os.path.dirname(fn.replace(substract_full_base, ""))[1:] for fn in all_img_filenames
        ])))

        self._image_list = []
        for base_folder in base_folders:
            img_filenames = [x for x in all_img_filenames if base_folder in x]
            # Pair consecutive frames within each scene.
            for i in range(len(img_filenames) - 1):
                im1 = img_filenames[i]
                im2 = img_filenames[i + 1]
                self._image_list += [[im1, im2]]
                # Sanity check: frame prefixes and numbers must line up.
                im1_base_filename = os.path.splitext(os.path.basename(im1))[0]
                im2_base_filename = os.path.splitext(os.path.basename(im2))[0]
                im1_frame, im1_no = im1_base_filename.split("_")
                im2_frame, im2_no = im2_base_filename.split("_")
                assert im1_frame == im2_frame
                assert int(im1_no) == int(im2_no) - 1

        full_num_examples = len(self._image_list)
        list_of_indices = range(full_num_examples)

        # Save list of actual filenames for inputs.
        self._image_list = [self._image_list[i] for i in list_of_indices]

        # ----------------------------------------------------------
        # photometric_augmentations
        # ----------------------------------------------------------
        if photometric_augmentations:
            self._photometric_transform = transforms.ConcatTransformSplitChainer([
                # uint8 -> PIL
                vision_transforms.ToPILImage(),
                # PIL -> PIL : random hsv and contrast
                vision_transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
                # PIL -> FloatTensor
                vision_transforms.transforms.ToTensor(),
                transforms.RandomGamma(min_gamma=0.7, max_gamma=1.5, clip_image=True),
            ], from_numpy=True, to_numpy=False)
        else:
            self._photometric_transform = transforms.ConcatTransformSplitChainer([
                # uint8 -> FloatTensor
                vision_transforms.transforms.ToTensor(),
            ], from_numpy=True, to_numpy=False)

        self._size = len(self._image_list)

    def __getitem__(self, index):
        """Return one example dict with the transformed image pair and naming info."""
        index = index % self._size
        im1_filename = self._image_list[index][0]
        im2_filename = self._image_list[index][1]
        # read uint8 images
        im1_np0 = common.read_image_as_byte(im1_filename)
        im2_np0 = common.read_image_as_byte(im2_filename)
        # possibly apply photometric transformations
        im1, im2 = self._photometric_transform(im1_np0, im2_np0)
        # e.g. "clean/alley_1"
        basedir = os.path.splitext(os.path.dirname(im1_filename).replace(self._substract_base, "")[1:])[0]
        # e.g. "frame_0001"
        basename = os.path.splitext(os.path.basename(im1_filename))[0]
        example_dict = {
            "input1": im1,
            "input2": im2,
            "index": index,
            "basedir": basedir,
            "basename": basename
        }
        return example_dict

    def __len__(self):
        """Number of image pairs."""
        return self._size
class SintelTrainingCleanTrain(_Sintel):
    """Sintel training data, 'clean' pass, training split."""
    def __init__(self, args, root, photometric_augmentations=True):
        super(SintelTrainingCleanTrain, self).__init__(
            args,
            dir_root=os.path.join(root, "training"),
            photometric_augmentations=photometric_augmentations,
            imgtype="clean",
            dstype="train")
class SintelTrainingCleanValid(_Sintel):
    """Sintel training data, 'clean' pass, validation split."""
    def __init__(self, args, root, photometric_augmentations=False):
        super(SintelTrainingCleanValid, self).__init__(
            args,
            dir_root=os.path.join(root, "training"),
            photometric_augmentations=photometric_augmentations,
            imgtype="clean",
            dstype="valid")
class SintelTrainingCleanFull(_Sintel):
    """Sintel training data, 'clean' pass, full (train + valid) split."""
    def __init__(self, args, root, photometric_augmentations=True):
        super(SintelTrainingCleanFull, self).__init__(
            args,
            dir_root=os.path.join(root, "training"),
            photometric_augmentations=photometric_augmentations,
            imgtype="clean",
            dstype="full")
class SintelTrainingFinalTrain(_Sintel):
    """Sintel training data, 'final' pass, training split."""
    def __init__(self, args, root, photometric_augmentations=True):
        super(SintelTrainingFinalTrain, self).__init__(
            args,
            dir_root=os.path.join(root, "training"),
            photometric_augmentations=photometric_augmentations,
            imgtype="final",
            dstype="train")
class SintelTrainingFinalValid(_Sintel):
    """Sintel training data, 'final' pass, validation split."""
    def __init__(self, args, root, photometric_augmentations=False):
        super(SintelTrainingFinalValid, self).__init__(
            args,
            dir_root=os.path.join(root, "training"),
            photometric_augmentations=photometric_augmentations,
            imgtype="final",
            dstype="valid")
class SintelTrainingFinalFull(_Sintel):
    """Sintel training data, 'final' pass, full (train + valid) split."""
    def __init__(self, args, root, photometric_augmentations=True):
        super(SintelTrainingFinalFull, self).__init__(
            args,
            dir_root=os.path.join(root, "training"),
            photometric_augmentations=photometric_augmentations,
            imgtype="final",
            dstype="full")
class SintelTrainingCombTrain(_Sintel):
    """Sintel training data, combined clean+final passes, training split."""
    def __init__(self, args, root, photometric_augmentations=True):
        super(SintelTrainingCombTrain, self).__init__(
            args,
            dir_root=os.path.join(root, "training"),
            photometric_augmentations=photometric_augmentations,
            imgtype="comb",
            dstype="train")
class SintelTrainingCombValid(_Sintel):
    """Sintel training data, combined clean+final passes, validation split."""
    def __init__(self, args, root, photometric_augmentations=False):
        super(SintelTrainingCombValid, self).__init__(
            args,
            dir_root=os.path.join(root, "training"),
            photometric_augmentations=photometric_augmentations,
            imgtype="comb",
            dstype="valid")
class SintelTrainingCombFull(_Sintel):
    """Sintel training data, combined clean+final passes, full split."""
    def __init__(self, args, root, photometric_augmentations=True):
        super(SintelTrainingCombFull, self).__init__(
            args,
            dir_root=os.path.join(root, "training"),
            photometric_augmentations=photometric_augmentations,
            imgtype="comb",
            dstype="full")
class SintelTestClean(_Sintel_test):
    """Sintel test data, 'clean' pass (no ground truth)."""
    def __init__(self, args, root, photometric_augmentations=False):
        super(SintelTestClean, self).__init__(
            args,
            dir_root=os.path.join(root, "test"),
            photometric_augmentations=photometric_augmentations,
            imgtype="clean")
class SintelTestFinal(_Sintel_test):
    """Sintel test data, 'final' pass (no ground truth)."""
    def __init__(self, args, root, photometric_augmentations=False):
        super(SintelTestFinal, self).__init__(
            args,
            dir_root=os.path.join(root, "test"),
            photometric_augmentations=photometric_augmentations,
            imgtype="final")
================================================
FILE: datasets/transforms.py
================================================
## Portions of Code from, copyright 2018 Jochen Gast
from __future__ import absolute_import, division, print_function
import numpy as np
import torch
def image_random_gamma(image, min_gamma=0.7, max_gamma=1.5, clip_image=False):
    """Gamma-adjust a tensor image with an exponent drawn uniformly
    from [min_gamma, max_gamma]; optionally clamp the result to [0, 1]."""
    exponent = np.random.uniform(min_gamma, max_gamma)
    result = torch.pow(image, exponent)
    if clip_image:
        result.clamp_(0.0, 1.0)
    return result


class RandomGamma:
    """Callable transform: applies `image_random_gamma` with fixed bounds."""

    def __init__(self, min_gamma=0.7, max_gamma=1.5, clip_image=False):
        self._lo = min_gamma
        self._hi = max_gamma
        self._clip = clip_image

    def __call__(self, image):
        return image_random_gamma(
            image, min_gamma=self._lo, max_gamma=self._hi, clip_image=self._clip)
# ------------------------------------------------------------------
# Allow transformation chains of the type:
# im1, im2, .... = transform(im1, im2, ...)
# ------------------------------------------------------------------
class TransformChainer:
    """Apply a sequence of transforms to each argument independently:
    im1, im2, ... = chainer(im1, im2, ...).
    A single argument is returned unwrapped (not as a one-element list)."""

    def __init__(self, list_of_transforms):
        self._list_of_transforms = list_of_transforms

    def __call__(self, *args):
        outputs = list(args)
        for fn in self._list_of_transforms:
            outputs = [fn(item) for item in outputs]
        return outputs[0] if len(args) == 1 else outputs
# ------------------------------------------------------------------
# Allow transformation chains of the type:
# im1, im2, .... = split( transform( concatenate(im1, im2, ...) ))
# ------------------------------------------------------------------
class ConcatTransformSplitChainer:
    """Concatenate the inputs, run the transform chain once, then split the
    result back: im1, im2, ... = chainer(im1, im2, ...).
    Numpy inputs are joined along axis 0; torch inputs along dim 1. The same
    axes are used for the split on the way out."""

    def __init__(self, list_of_transforms, from_numpy=True, to_numpy=False):
        self._chainer = TransformChainer(list_of_transforms)
        self._from_numpy = from_numpy
        self._to_numpy = to_numpy

    def __call__(self, *args):
        num_splits = len(args)
        if self._from_numpy:
            joined = np.concatenate(args, axis=0)
        else:
            joined = torch.cat(args, dim=1)
        transformed = self._chainer(joined)
        if self._to_numpy:
            return np.split(transformed, indices_or_sections=num_splits, axis=0)
        return torch.chunk(transformed, num_splits, dim=1)
================================================
FILE: flyingchairsocc/README.md
================================================
# FlyingChairsOcc dataset
The FlyingChairsOcc dataset is an extended version of the Flying Chairs Dataset, including bi-directional optical flow ground truth and two occlusion maps for each image.
You may also find that another concurrent dataset, the Flying Chairs 2 Dataset, is useful.
## License agreement
This dataset is made freely available to academic and non-academic entities for non-commercial purposes such as academic research, teaching, scientific publications, or personal experimentation. Permission is granted to use the data given that you agree:
1. That the dataset comes “AS IS”, without express or implied warranty. Although every effort has been made to ensure accuracy, we (TU Darmstadt) do not accept any responsibility for errors or omissions.
2. That you include a reference to the FlyingChairsOcc Dataset in any work that makes use of the dataset. For research papers, cite our publication: J. Hur and S. Roth, “Iterative residual refinement for joint optical flow and occlusion estimation,” in Proc. of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), Long Beach, California, June 2019
3. That you do not distribute this dataset or modified versions. It is permissible to distribute derivative works in as far as they are abstract representations of this dataset (such as models trained on it or additional annotations that do not directly include any of our data) and do not allow to recover the dataset or something similar in character.
4. That you may not use the dataset or any derivative work for commercial purposes as, for example, licensing or selling the data, or using the data with a purpose to procure a commercial gain.
5. That all rights not expressly granted to you are reserved by us (TU Darmstadt).
## Download link
Download
(82GB)
## Reference
Please cite the paper below if you find the dataset and source code useful.
@inproceedings{Hur:2019:IRR,
Author = {Junhwa Hur and Stefan Roth},
Booktitle = {CVPR},
Title = {Iterative Residual Refinement for Joint Optical Flow and Occlusion Estimation},
Year = {2019}
}
Contact: junhwa.hur[at]visinf.tu-darmstadt.de
================================================
FILE: install.sh
================================================
#!/bin/bash
# Build and install the custom correlation CUDA extension.
# Abort on the first failing command: without this, a failed `cd` (e.g. when
# run from the wrong directory) would still execute `setup.py install` from
# whatever the current directory happens to be.
set -e
cd ./models/correlation_package
python setup.py install
cd ..
================================================
FILE: logger.py
================================================
## Portions of Code from, copyright 2018 Jochen Gast
from __future__ import absolute_import, division, print_function
import colorama
import logging
import os
import re
import tools
import sys
def get_default_logging_format(colorize=False, brackets=False):
    """Build the log line format: a (optionally bracketed, optionally
    ANSI-colorized) timestamp followed by the message."""
    if colorize:
        style = colorama.Style.DIM
        color = colorama.Fore.WHITE
        reset = colorama.Style.RESET_ALL
    else:
        style = color = reset = ''
    timestamp = "[%(asctime)s]" if brackets else "%(asctime)s"
    return "{}{}{}{} %(message)s".format(style, color, timestamp, reset)
def get_default_logging_datefmt():
    """Return the timestamp format shared by all log handlers."""
    return "%Y-%m-%d %H:%M:%S"
def log_module_info(module):
    """Log the module's string representation, one INFO record per line."""
    for text_line in str(module).split("\n"):
        logging.info(text_line)
class LogbookFormatter(logging.Formatter):
    """Formatter that strips ANSI color escape sequences (e.g. "\\033[36m")
    from the message before formatting, for plain-text logbook files."""

    def __init__(self, fmt=None, datefmt=None):
        super(LogbookFormatter, self).__init__(fmt=fmt, datefmt=datefmt)
        self._re = re.compile(r"\033\[[0-9]+m")

    def remove_colors_from_msg(self, msg):
        """Return `msg` with every ANSI color code removed."""
        return self._re.sub("", msg)

    def format(self, record=None):
        record.msg = self.remove_colors_from_msg(record.msg)
        return super(LogbookFormatter, self).format(record)
class ConsoleFormatter(logging.Formatter):
    """Formatter that prefixes each message with the module-global indent
    level (maintained via sys.modules[__name__].global_indent)."""

    def __init__(self, fmt=None, datefmt=None):
        super(ConsoleFormatter, self).__init__(fmt=fmt, datefmt=datefmt)

    def format(self, record=None):
        indent = sys.modules[__name__].global_indent
        record.msg = "%s%s" % (" " * indent, record.msg)
        return super(ConsoleFormatter, self).format(record)
class SkipLogbookFilter(logging.Filter):
    """Filter that drops records at the custom LOGBOOK level (registered
    dynamically via tools.addLoggingLevel) so they skip the console."""
    def filter(self, record):
        return record.levelno != logging.LOGBOOK
def configure_logging(filename=None):
    """Configure root logging: a colorized console handler plus, when
    `filename` is given, a color-stripped file "logbook" handler.

    Args:
        filename: optional path of the logbook file; the parent directory
            is created if missing.
    """
    # set global indent level (read by ConsoleFormatter / LoggingBlock)
    sys.modules[__name__].global_indent = 0
    # add custom logbook logging level
    tools.addLoggingLevel("LOGBOOK", 1000)
    # create logger
    root_logger = logging.getLogger("")
    root_logger.setLevel(logging.INFO)
    # create console handler and set level to info
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    fmt = get_default_logging_format(colorize=True, brackets=False)
    datefmt = get_default_logging_datefmt()
    formatter = ConsoleFormatter(fmt=fmt, datefmt=datefmt)
    console.setFormatter(formatter)
    # Skip LOGBOOK-level requests for console outputs
    skip_logbook_filter = SkipLogbookFilter()
    console.addFilter(skip_logbook_filter)
    # add console to root_logger
    root_logger.addHandler(console)
    # add logbook
    if filename is not None:
        # ensure dir; guard against an empty dirname -- a bare filename
        # such as "run.log" would otherwise crash os.makedirs("")
        d = os.path.dirname(filename)
        if d and not os.path.exists(d):
            os.makedirs(d)
        # --------------------------------------------------------------------------------------
        # Configure handler that removes color codes from logbook
        # --------------------------------------------------------------------------------------
        logbook = logging.FileHandler(filename=filename, mode="a", encoding="utf-8")
        logbook.setLevel(logging.INFO)
        fmt = get_default_logging_format(colorize=False, brackets=True)
        logbook_formatter = LogbookFormatter(fmt=fmt, datefmt=datefmt)
        logbook.setFormatter(logbook_formatter)
        root_logger.addHandler(logbook)
class LoggingBlock:
    """Context manager that logs a (possibly emphasized) title and indents
    all console output produced inside the `with` body by two spaces."""

    def __init__(self, title, emph=False):
        self._emph = emph
        if emph:
            bright = colorama.Style.BRIGHT
            cyan = colorama.Fore.CYAN
            reset = colorama.Style.RESET_ALL
            logging.info("%s==>%s %s%s%s" % (cyan, reset, bright, title, reset))
        else:
            logging.info(title)

    def __enter__(self):
        sys.modules[__name__].global_indent += 2
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        sys.modules[__name__].global_indent -= 2
================================================
FILE: losses.py
================================================
from __future__ import absolute_import, division, print_function
import torch
import torch.nn as nn
import torch.nn.functional as tf
def _elementwise_epe(input_flow, target_flow):
residual = target_flow - input_flow
return torch.norm(residual, p=2, dim=1, keepdim=True)
def _elementwise_robust_epe_char(input_flow, target_flow):
residual = target_flow - input_flow
return torch.pow(torch.norm(residual, p=2, dim=1, keepdim=True) + 0.01, 0.4)
def _downsample2d_as(inputs, target_as):
_, _, h, w = target_as.size()
return tf.adaptive_avg_pool2d(inputs, [h, w])
def _upsample2d_as(inputs, target_as, mode="bilinear"):
_, _, h, w = target_as.size()
return tf.interpolate(inputs, [h, w], mode=mode, align_corners=True)
def f1_score(y_true, y_pred):
    """F1 score: fbeta_score with beta=1 (harmonic mean of precision/recall)."""
    return fbeta_score(y_true, y_pred, 1)
def fbeta_score(y_true, y_pred, beta, eps=1e-8):
    """Differentiable F-beta score, reduced over the spatial dimensions per
    (batch, channel) and then averaged; `eps` stabilizes empty masks."""
    beta_sq = beta ** 2
    pred = y_pred.float()
    truth = y_true.float()
    tp = (pred * truth).sum(dim=(2, 3))
    precision = tp / (pred.sum(dim=(2, 3)) + eps)
    recall = tp / (truth.sum(dim=(2, 3)) + eps)
    fbeta = (1 + beta_sq) * precision * recall / (precision * beta_sq + recall + eps)
    return torch.mean(fbeta)
def f1_score_bal_loss(y_pred, y_true):
    """Class-balanced cross-entropy-style occlusion loss: positive and
    negative CE terms, each normalized by its class mass, scaled by the
    number of pixels."""
    eps = 1e-8
    pos_ce = -(y_true * torch.log(y_pred + eps)).sum(dim=(2, 3)).sum(dim=1)
    neg_ce = -((1 - y_true) * torch.log((1 - y_pred) + eps)).sum(dim=(2, 3)).sum(dim=1)
    pos_norm = y_true.sum(dim=(2, 3)).sum(dim=1) + y_pred.sum(dim=(2, 3)).sum(dim=1) + eps
    neg_norm = (1 - y_true).sum(dim=(2, 3)).sum(dim=1) + (1 - y_pred).sum(dim=(2, 3)).sum(dim=1) + eps
    num_pixels = y_pred.size(2) * y_pred.size(3)
    return ((pos_ce / pos_norm).sum() + (neg_ce / neg_norm).sum()) * num_pixels * 0.5
class MultiScaleEPE_FlowNet(nn.Module):
    """Multi-scale endpoint-error loss for FlowNet (flow only).

    Training: weighted sum of per-scale EPE sums over flow2..flow6,
    normalized by batch size. Evaluation: mean EPE of flow1.
    """

    def __init__(self, args):
        super(MultiScaleEPE_FlowNet, self).__init__()
        self._args = args
        self._batch_size = args.batch_size
        # per-scale weights for flow2 .. flow6
        self._weights = [0.005, 0.01, 0.02, 0.08, 0.32]

    def forward(self, output_dict, target_dict):
        loss_dict = {}
        if self.training:
            # div_flow trick: target scaled to the network's output range
            target = self._args.model_div_flow * target_dict["target1"]
            total_loss = 0
            for level, key in enumerate(["flow2", "flow3", "flow4", "flow5", "flow6"]):
                prediction = output_dict[key]
                epe = _elementwise_epe(prediction, _downsample2d_as(target, prediction))
                total_loss = total_loss + self._weights[level] * epe.sum() / self._batch_size
                loss_dict["epe%i" % (level + 2)] = epe.mean()
            loss_dict["total_loss"] = total_loss
        else:
            epe = _elementwise_epe(output_dict["flow1"], target_dict["target1"])
            loss_dict["epe"] = epe.mean()
        return loss_dict
class MultiScaleEPE_FlowNet_IRR(nn.Module):
    """Multi-scale EPE loss for IRR-FlowNet: each scale provides one flow
    estimate per IRR iteration; the total is averaged over iterations."""

    def __init__(self, args):
        super(MultiScaleEPE_FlowNet_IRR, self).__init__()
        self._args = args
        self._batch_size = args.batch_size
        # per-scale weights for flow2 .. flow6
        self._weights = [0.005, 0.01, 0.02, 0.08, 0.32]
        self._num_iters = args.num_iters

    def forward(self, output_dict, target_dict):
        loss_dict = {}
        if self.training:
            # div_flow trick
            target_f = self._args.model_div_flow * target_dict["target1"]
            total_loss = 0
            for level, key in enumerate(["flow2", "flow3", "flow4", "flow5", "flow6"]):
                per_iter_flows = output_dict[key]
                target_level = _downsample2d_as(target_f, per_iter_flows[0])
                for flow_estimate in per_iter_flows:
                    epe = _elementwise_epe(flow_estimate, target_level)
                    total_loss = total_loss + self._weights[level] * epe.sum()
                    # reported value is the last iteration's mean EPE
                    loss_dict["epe%i" % (level + 2)] = epe.mean()
            loss_dict["total_loss"] = total_loss / self._batch_size / self._num_iters
        else:
            epe_f = _elementwise_epe(target_dict["target1"], output_dict["flow1"])
            loss_dict["epe"] = epe_f.mean()
        return loss_dict
class MultiScaleEPE_FlowNet_IRR_Bi(nn.Module):
    """Multi-scale EPE loss for bidirectional IRR-FlowNet: each scale holds,
    per IRR iteration, a (forward, backward) flow pair."""

    def __init__(self, args):
        super(MultiScaleEPE_FlowNet_IRR_Bi, self).__init__()
        self._args = args
        self._batch_size = args.batch_size
        # per-scale weights for flow2 .. flow6
        self._weights = [0.005, 0.01, 0.02, 0.08, 0.32]
        self._num_iters = args.num_iters

    def forward(self, output_dict, target_dict):
        loss_dict = {}
        if self.training:
            # div_flow trick
            target_f = self._args.model_div_flow * target_dict["target1"]
            target_b = self._args.model_div_flow * target_dict["target2"]
            total_loss = 0
            for level, key in enumerate(["flow2", "flow3", "flow4", "flow5", "flow6"]):
                per_iter_pairs = output_dict[key]
                target_f_level = _downsample2d_as(target_f, per_iter_pairs[0][0])
                target_b_level = _downsample2d_as(target_b, per_iter_pairs[0][1])
                for pair in per_iter_pairs:
                    epe_f = _elementwise_epe(pair[0], target_f_level)
                    epe_b = _elementwise_epe(pair[1], target_b_level)
                    total_loss = total_loss + self._weights[level] * (epe_f.sum() + epe_b.sum())
                    # reported value is the last iteration's mean of both directions
                    loss_dict["epe%i" % (level + 2)] = (epe_f.mean() + epe_b.mean()) / 2
            loss_dict["total_loss"] = total_loss / self._batch_size / self._num_iters / 2
        else:
            epe_f = _elementwise_epe(output_dict["flow1"], target_dict["target1"])
            loss_dict["epe"] = epe_f.mean()
        return loss_dict
class MultiScaleEPE_FlowNet_IRR_Occ(nn.Module):
    """Multi-scale flow + occlusion loss for the iterative-residual FlowNet.

    Training: weighted EPE over five scales and all iterations, plus a
    balanced-F1 loss on the sigmoid occlusion maps; the two terms are
    auto-balanced to comparable magnitudes. Validation: mean EPE and the
    F1 score of the rounded occlusion prediction.
    """

    def __init__(self, args):
        super(MultiScaleEPE_FlowNet_IRR_Occ, self).__init__()
        self._args = args
        self._batch_size = args.batch_size
        # Finest (flow2/occ2) to coarsest (flow6/occ6).
        self._weights = [0.005, 0.01, 0.02, 0.08, 0.32]
        self._num_iters = args.num_iters
        self.f1_score_bal_loss = f1_score_bal_loss
        self.occ_activ = nn.Sigmoid()

    def forward(self, output_dict, target_dict):
        loss_dict = {}

        if self.training:
            # div_flow trick: flow targets are scaled to the network's range.
            gt_flow = self._args.model_div_flow * target_dict["target1"]
            gt_occ = target_dict["target_occ1"]

            flow_loss = 0
            for level, key in enumerate(["flow2", "flow3", "flow4", "flow5", "flow6"]):
                iter_preds = output_dict[key]
                gt_flow_lvl = _downsample2d_as(gt_flow, iter_preds[0])
                for pred in iter_preds:
                    flow_loss = flow_loss + self._weights[level] * _elementwise_epe(pred, gt_flow_lvl).sum()

            occ_loss = 0
            for level, key in enumerate(["occ2", "occ3", "occ4", "occ5", "occ6"]):
                iter_preds = output_dict[key]
                gt_occ_lvl = _downsample2d_as(gt_occ, iter_preds[0])
                for pred in iter_preds:
                    occ_loss = occ_loss + self._weights[level] * self.f1_score_bal_loss(self.occ_activ(pred), gt_occ_lvl)

            # Rescale the smaller loss to the larger one's magnitude; the
            # factors are detached so the balancing carries no gradient.
            f_mag = flow_loss.detach()
            o_mag = occ_loss.detach()
            if f_mag > o_mag:
                f_w, o_w = 1, f_mag / o_mag
            else:
                f_w, o_w = o_mag / f_mag, 1

            loss_dict["flow_loss"] = flow_loss / self._batch_size / self._num_iters
            loss_dict["occ_loss"] = occ_loss / self._batch_size / self._num_iters
            loss_dict["total_loss"] = (flow_loss * f_w + occ_loss * o_w) / self._batch_size / self._num_iters
        else:
            loss_dict["epe"] = _elementwise_epe(output_dict["flow1"], target_dict["target1"]).mean()
            loss_dict["F1"] = f1_score(target_dict["target_occ1"], torch.round(self.occ_activ(output_dict["occ1"])))

        return loss_dict
class MultiScaleEPE_FlowNet_IRR_Bi_Occ(nn.Module):
    """Bidirectional multi-scale flow + occlusion loss for IRR-FlowNet.

    Training sums per-scale, per-iteration losses: end-point error for the
    forward/backward flows and a balanced-F1 loss for the sigmoid occlusion
    maps, auto-balancing the two terms to comparable magnitudes.
    Validation reports mean EPE and occlusion F1 of the final forward outputs.
    """

    def __init__(self, args):
        super(MultiScaleEPE_FlowNet_IRR_Bi_Occ, self).__init__()
        self._args = args
        self._batch_size = args.batch_size
        # Finest (flow2) to coarsest (flow6).
        self._weights = [0.005, 0.01, 0.02, 0.08, 0.32]
        self._num_iters = args.num_iters
        self.f1_score_bal_loss = f1_score_bal_loss
        self.occ_activ = nn.Sigmoid()

    def forward(self, output_dict, target_dict):
        loss_dict = {}

        if self.training:
            outputs_flo = [output_dict[key] for key in ["flow2", "flow3", "flow4", "flow5", "flow6"]]
            outputs_occ = [output_dict[key] for key in ["occ2", "occ3", "occ4", "occ5", "occ6"]]

            # div_flow trick
            target_f = self._args.model_div_flow * target_dict["target1"]
            target_b = self._args.model_div_flow * target_dict["target2"]
            target_occ_f = target_dict["target_occ1"]
            target_occ_b = target_dict["target_occ2"]

            flow_loss = 0
            occ_loss = 0

            for ii, output_ii in enumerate(outputs_flo):
                target_f_ii = _downsample2d_as(target_f, output_ii[0][0])
                target_b_ii = _downsample2d_as(target_b, output_ii[0][1])
                for output_ii_jj in output_ii:
                    epe_f_ii = _elementwise_epe(output_ii_jj[0], target_f_ii)
                    epe_b_ii = _elementwise_epe(output_ii_jj[1], target_b_ii)
                    flow_loss = flow_loss + self._weights[ii] * (epe_f_ii.sum() + epe_b_ii.sum()) * 0.5

            for ii, output_ii in enumerate(outputs_occ):
                # BUGFIX: resample from the ORIGINAL occlusion targets at every
                # level. The previous code rebound target_occ_f/_b to the
                # downsampled result, so each deeper level re-downsampled an
                # already-downsampled map, compounding interpolation error —
                # inconsistent with the flow branch above, which always
                # resamples from the original full-resolution target.
                target_occ_f_ii = _downsample2d_as(target_occ_f, output_ii[0][0])
                target_occ_b_ii = _downsample2d_as(target_occ_b, output_ii[0][1])
                for output_ii_jj in output_ii:
                    output_occ_f = self.occ_activ(output_ii_jj[0])
                    output_occ_b = self.occ_activ(output_ii_jj[1])
                    bce_f_ii = self.f1_score_bal_loss(output_occ_f, target_occ_f_ii)
                    bce_b_ii = self.f1_score_bal_loss(output_occ_b, target_occ_b_ii)
                    occ_loss = occ_loss + self._weights[ii] * (bce_f_ii + bce_b_ii) * 0.5

            # Balance flow vs. occlusion magnitudes (scale factors detached,
            # so the balancing itself is not differentiated through).
            f_loss = flow_loss.detach()
            o_loss = occ_loss.detach()
            if f_loss > o_loss:
                f_l_w = 1
                o_l_w = f_loss / o_loss
            else:
                f_l_w = o_loss / f_loss
                o_l_w = 1

            loss_dict["flow_loss"] = flow_loss / self._batch_size / self._num_iters
            loss_dict["occ_loss"] = occ_loss / self._batch_size / self._num_iters
            loss_dict["total_loss"] = (flow_loss * f_l_w + occ_loss * o_l_w) / self._batch_size / self._num_iters
        else:
            loss_dict["epe"] = _elementwise_epe(output_dict["flow1"], target_dict["target1"]).mean()
            loss_dict["F1"] = f1_score(target_dict["target_occ1"], torch.round(self.occ_activ(output_dict["occ1"])))

        return loss_dict
class MultiScaleEPE_FlowNet_IRR_Bi_Occ_upsample(nn.Module):
    """Bidirectional multi-scale flow + occlusion loss including the extra
    refined/upsampled outputs ("flow"/"occ") of IRR-FlowNet.

    Training sums weighted EPE and balanced-F1 occlusion losses over seven
    output levels and all iterations, auto-balancing the two terms.
    Validation reports mean EPE and occlusion F1 of the final forward output.
    """

    def __init__(self, args):
        super(MultiScaleEPE_FlowNet_IRR_Bi_Occ_upsample, self).__init__()
        self._args = args
        self._batch_size = args.batch_size
        # Weights ordered finest ("flow") to coarsest ("flow6").
        self._weights = [0.0003125, 0.00125, 0.005, 0.01, 0.02, 0.08, 0.32]
        self.occ_activ = nn.Sigmoid()
        self.f1_score_bal_loss = f1_score_bal_loss

    def forward(self, output_dict, target_dict):
        loss_dict = {}

        if self.training:
            outputs_flo = [output_dict[key] for key in ["flow", "flow1", "flow2", "flow3", "flow4", "flow5", "flow6"]]
            outputs_occ = [output_dict[key] for key in ["occ", "occ1", "occ2", "occ3", "occ4", "occ5", "occ6"]]

            # div_flow trick
            target_f = self._args.model_div_flow * target_dict["target1"]
            target_b = self._args.model_div_flow * target_dict["target2"]
            target_occ_f = target_dict["target_occ1"]
            target_occ_b = target_dict["target_occ2"]

            # Number of refinement iterations, inferred from the outputs.
            num_iters = len(outputs_flo[0])

            flow_loss = 0
            occ_loss = 0

            for ii, output_ii in enumerate(outputs_flo):
                target_f_ii = _downsample2d_as(target_f, output_ii[0][0])
                target_b_ii = _downsample2d_as(target_b, output_ii[0][1])
                for output_ii_jj in output_ii:
                    epe_f_ii = _elementwise_epe(output_ii_jj[0], target_f_ii)
                    epe_b_ii = _elementwise_epe(output_ii_jj[1], target_b_ii)
                    flow_loss = flow_loss + self._weights[ii] * (epe_f_ii.sum() + epe_b_ii.sum()) * 0.5

            for ii, output_ii in enumerate(outputs_occ):
                # BUGFIX: downsample from the ORIGINAL full-resolution
                # occlusion targets at every level. The previous code rebound
                # target_occ_f/_b to the result, so deeper levels
                # re-downsampled an already-downsampled map, compounding
                # interpolation error (the flow branch above always resamples
                # from the original target).
                target_occ_f_ii = _downsample2d_as(target_occ_f, output_ii[0][0])
                target_occ_b_ii = _downsample2d_as(target_occ_b, output_ii[0][1])
                for output_ii_jj in output_ii:
                    output_occ_f = self.occ_activ(output_ii_jj[0])
                    output_occ_b = self.occ_activ(output_ii_jj[1])
                    bce_f_ii = self.f1_score_bal_loss(output_occ_f, target_occ_f_ii)
                    bce_b_ii = self.f1_score_bal_loss(output_occ_b, target_occ_b_ii)
                    occ_loss = occ_loss + self._weights[ii] * (bce_f_ii + bce_b_ii) * 0.5

            # Balance flow vs. occlusion magnitudes; factors are detached so
            # the balancing is not differentiated through.
            f_loss = flow_loss.detach()
            o_loss = occ_loss.detach()
            if f_loss > o_loss:
                f_l_w = 1
                o_l_w = f_loss / o_loss
            else:
                f_l_w = o_loss / f_loss
                o_l_w = 1

            loss_dict["flow_loss"] = flow_loss / self._batch_size / num_iters
            loss_dict["occ_loss"] = occ_loss / self._batch_size / num_iters
            loss_dict["total_loss"] = (flow_loss * f_l_w + occ_loss * o_l_w) / self._batch_size / num_iters
        else:
            loss_dict["epe"] = _elementwise_epe(output_dict["flow"], target_dict["target1"]).mean()
            loss_dict["F1"] = f1_score(target_dict["target_occ1"], torch.round(self.occ_activ(output_dict["occ"])))

        return loss_dict
class MultiScaleEPE_PWC(nn.Module):
    """Standard multi-scale end-point-error training loss for PWC-Net.

    Training: per-level EPE sums against the div_flow-scaled target,
    weighted coarsest-to-finest. Validation: mean EPE of the final flow.
    """

    def __init__(self, args):
        super(MultiScaleEPE_PWC, self).__init__()
        self._args = args
        self._batch_size = args.batch_size
        # Coarsest to finest decoder level.
        self._weights = [0.32, 0.08, 0.02, 0.01, 0.005]

    def forward(self, output_dict, target_dict):
        loss_dict = {}

        if self.training:
            # div_flow trick: targets scaled to the network's output range.
            gt = self._args.model_div_flow * target_dict["target1"]

            total_loss = 0
            for weight, pred in zip(self._weights, output_dict['flow']):
                total_loss = total_loss + weight * _elementwise_epe(pred, _downsample2d_as(gt, pred)).sum()
            loss_dict["total_loss"] = total_loss / self._batch_size
        else:
            loss_dict["epe"] = _elementwise_epe(output_dict["flow"], target_dict["target1"]).mean()

        return loss_dict
class MultiScaleEPE_PWC_Bi(nn.Module):
    """Bidirectional multi-scale EPE training loss for PWC-Net.

    Each per-level output is a [forward, backward] pair; both directions are
    supervised and the sum is normalized by 2 * batch size.
    """

    def __init__(self, args):
        super(MultiScaleEPE_PWC_Bi, self).__init__()
        self._args = args
        self._batch_size = args.batch_size
        # Coarsest to finest decoder level.
        self._weights = [0.32, 0.08, 0.02, 0.01, 0.005]

    def forward(self, output_dict, target_dict):
        loss_dict = {}

        if self.training:
            # div_flow trick
            gt_fwd = self._args.model_div_flow * target_dict["target1"]
            gt_bwd = self._args.model_div_flow * target_dict["target2"]

            total_loss = 0
            for weight, (pred_f, pred_b) in zip(self._weights, output_dict['flow']):
                epe_fwd = _elementwise_epe(pred_f, _downsample2d_as(gt_fwd, pred_f))
                epe_bwd = _elementwise_epe(pred_b, _downsample2d_as(gt_bwd, pred_b))
                total_loss = total_loss + weight * (epe_fwd.sum() + epe_bwd.sum())
            loss_dict["total_loss"] = total_loss / (2 * self._batch_size)
        else:
            loss_dict["epe"] = _elementwise_epe(output_dict["flow"], target_dict["target1"]).mean()

        return loss_dict
class MultiScaleEPE_PWC_Occ(nn.Module):
    """Multi-scale flow + occlusion loss for PWC-Net (unidirectional).

    Training: weighted per-level EPE plus balanced-F1 occlusion loss, with
    the two terms auto-balanced to comparable magnitudes. Validation: mean
    EPE and the F1 score of the rounded occlusion prediction.
    """

    def __init__(self, args):
        super(MultiScaleEPE_PWC_Occ, self).__init__()
        self._args = args
        self._batch_size = args.batch_size
        # Coarsest to finest decoder level.
        self._weights = [0.32, 0.08, 0.02, 0.01, 0.005]
        self.occ_activ = nn.Sigmoid()
        self.f1_score_bal_loss = f1_score_bal_loss

    def forward(self, output_dict, target_dict):
        loss_dict = {}

        if self.training:
            output_flo = output_dict['flow']
            output_occ = output_dict['occ']

            # div_flow trick
            target_flo = self._args.model_div_flow * target_dict["target1"]
            target_occ = target_dict["target_occ1"]

            flow_loss = 0
            occ_loss = 0

            for i, output_i in enumerate(output_flo):
                flow_loss = flow_loss + self._weights[i] * _elementwise_epe(output_i, _downsample2d_as(target_flo, output_i)).sum()

            # FIX: use a distinct name for the activated prediction. The old
            # code rebound `output_occ` — the list being iterated — inside the
            # loop body; it only worked because enumerate() had already
            # captured the iterator, and it silently destroyed the reference
            # to the outputs list.
            for i, occ_logits_i in enumerate(output_occ):
                occ_pred_i = self.occ_activ(occ_logits_i)
                occ_loss = occ_loss + self._weights[i] * self.f1_score_bal_loss(occ_pred_i, _downsample2d_as(target_occ, occ_pred_i))

            # Balance flow vs. occlusion magnitudes (factors detached, so the
            # balancing carries no gradient).
            f_loss = flow_loss.detach()
            o_loss = occ_loss.detach()
            if f_loss > o_loss:
                f_l_w = 1
                o_l_w = f_loss / o_loss
            else:
                f_l_w = o_loss / f_loss
                o_l_w = 1

            loss_dict["flow_loss"] = flow_loss / self._batch_size
            loss_dict["occ_loss"] = occ_loss / self._batch_size
            loss_dict["total_loss"] = (flow_loss * f_l_w + occ_loss * o_l_w) / self._batch_size
        else:
            loss_dict["epe"] = _elementwise_epe(output_dict["flow"], target_dict["target1"]).mean()
            loss_dict["F1"] = f1_score(target_dict["target_occ1"], torch.round(self.occ_activ(output_dict["occ"])))

        return loss_dict
class MultiScaleEPE_PWC_Bi_Occ(nn.Module):
    """Bidirectional multi-scale flow + occlusion loss for PWC-Net.

    Each per-level output is a [forward, backward] pair; both directions are
    supervised (EPE for flow, balanced-F1 for sigmoid occlusion), and the two
    loss terms are auto-balanced to comparable magnitudes.
    """

    def __init__(self, args):
        super(MultiScaleEPE_PWC_Bi_Occ, self).__init__()
        self._args = args
        self._batch_size = args.batch_size
        # Coarsest to finest decoder level.
        self._weights = [0.32, 0.08, 0.02, 0.01, 0.005]
        self.occ_activ = nn.Sigmoid()
        self.f1_score_bal_loss = f1_score_bal_loss

    def forward(self, output_dict, target_dict):
        loss_dict = {}

        if self.training:
            # div_flow trick
            gt_flow_f = self._args.model_div_flow * target_dict["target1"]
            gt_flow_b = self._args.model_div_flow * target_dict["target2"]
            gt_occ_f = target_dict["target_occ1"]
            gt_occ_b = target_dict["target_occ2"]

            # bchw
            flow_loss = 0
            for weight, (pred_f, pred_b) in zip(self._weights, output_dict['flow']):
                flow_loss = flow_loss + weight * _elementwise_epe(pred_f, _downsample2d_as(gt_flow_f, pred_f)).sum()
                flow_loss = flow_loss + weight * _elementwise_epe(pred_b, _downsample2d_as(gt_flow_b, pred_b)).sum()

            occ_loss = 0
            for weight, (logit_f, logit_b) in zip(self._weights, output_dict['occ']):
                prob_f = self.occ_activ(logit_f)
                prob_b = self.occ_activ(logit_b)
                occ_loss = occ_loss + weight * self.f1_score_bal_loss(prob_f, _downsample2d_as(gt_occ_f, prob_f))
                occ_loss = occ_loss + weight * self.f1_score_bal_loss(prob_b, _downsample2d_as(gt_occ_b, prob_b))

            # Rescale the smaller loss to the larger one's magnitude; the
            # factors are detached so no gradient flows through the balancing.
            f_mag = flow_loss.detach()
            o_mag = occ_loss.detach()
            if f_mag > o_mag:
                f_w, o_w = 1, f_mag / o_mag
            else:
                f_w, o_w = o_mag / f_mag, 1

            denom = 2 * self._batch_size
            loss_dict["flow_loss"] = flow_loss / denom
            loss_dict["occ_loss"] = occ_loss / denom
            loss_dict["total_loss"] = (flow_loss * f_w + occ_loss * o_w) / denom
        else:
            loss_dict["epe"] = _elementwise_epe(output_dict["flow"], target_dict["target1"]).mean()
            loss_dict["F1"] = f1_score(target_dict["target_occ1"], torch.round(self.occ_activ(output_dict["occ"])))

        return loss_dict
class MultiScaleEPE_PWC_Bi_Occ_upsample(nn.Module):
    """Bidirectional multi-scale flow + occlusion loss where each level's
    output list interleaves directions (even index: forward, odd: backward),
    averaged over the entries of the level before weighting.
    """

    def __init__(self, args):
        super(MultiScaleEPE_PWC_Bi_Occ_upsample, self).__init__()
        self._args = args
        self._batch_size = args.batch_size
        # Seven levels, coarsest to finest (incl. upsampled outputs).
        self._weights = [0.32, 0.08, 0.02, 0.01, 0.005, 0.00125, 0.0003125]
        self.occ_activ = nn.Sigmoid()
        self.f1_score_bal_loss = f1_score_bal_loss

    def forward(self, output_dict, target_dict):
        loss_dict = {}

        if self.training:
            # div_flow trick
            gt_flow_f = self._args.model_div_flow * target_dict["target1"]
            gt_flow_b = self._args.model_div_flow * target_dict["target2"]
            gt_occ_f = target_dict["target_occ1"]
            gt_occ_b = target_dict["target_occ2"]

            # bchw
            flow_loss = 0
            for level, preds in enumerate(output_dict['flow']):
                level_sum = 0
                # Pair up (forward, backward) entries: even/odd indices.
                for pred_f, pred_b in zip(preds[0::2], preds[1::2]):
                    level_sum = level_sum + _elementwise_epe(pred_f, _downsample2d_as(gt_flow_f, pred_f)).sum()
                    level_sum = level_sum + _elementwise_epe(pred_b, _downsample2d_as(gt_flow_b, pred_b)).sum()
                flow_loss = flow_loss + self._weights[level] * level_sum / len(preds)

            occ_loss = 0
            for level, preds in enumerate(output_dict['occ']):
                level_sum = 0
                for logit_f, logit_b in zip(preds[0::2], preds[1::2]):
                    prob_f = self.occ_activ(logit_f)
                    prob_b = self.occ_activ(logit_b)
                    level_sum = level_sum + self.f1_score_bal_loss(prob_f, _downsample2d_as(gt_occ_f, prob_f))
                    level_sum = level_sum + self.f1_score_bal_loss(prob_b, _downsample2d_as(gt_occ_b, prob_b))
                occ_loss = occ_loss + self._weights[level] * level_sum / len(preds)

            # Rescale the smaller term to the larger one's magnitude; the
            # factors are detached so they carry no gradient.
            f_mag = flow_loss.detach()
            o_mag = occ_loss.detach()
            if f_mag > o_mag:
                f_w, o_w = 1, f_mag / o_mag
            else:
                f_w, o_w = o_mag / f_mag, 1

            loss_dict["flow_loss"] = flow_loss / self._batch_size
            loss_dict["occ_loss"] = occ_loss / self._batch_size
            loss_dict["total_loss"] = (flow_loss * f_w + occ_loss * o_w) / self._batch_size
        else:
            loss_dict["epe"] = _elementwise_epe(output_dict["flow"], target_dict["target1"]).mean()
            loss_dict["F1"] = f1_score(target_dict["target_occ1"], torch.round(self.occ_activ(output_dict["occ"])))

        return loss_dict
class MultiScaleEPE_PWC_Bi_Occ_upsample_Sintel(nn.Module):
    """Multi-scale loss for fine-tuning the bidirectional PWC model on Sintel.

    Sintel ground truth covers only the forward direction, so the backward
    predictions are detached in place and only forward flow (robust
    Charbonnier EPE) and forward occlusion (sum-reduced BCE) are supervised.
    """

    def __init__(self,
                 args):

        super(MultiScaleEPE_PWC_Bi_Occ_upsample_Sintel, self).__init__()
        self._args = args
        self._batch_size = args.batch_size
        # Seven levels, coarsest to finest (incl. upsampled outputs).
        self._weights = [0.32, 0.08, 0.02, 0.01, 0.005, 0.00125, 0.0003125]
        self.occ_activ = nn.Sigmoid()
        # Sum reduction: normalization is done explicitly below.
        self.occ_loss_bce = nn.BCELoss(reduction='sum')

    def forward(self, output_dict, target_dict):
        loss_dict = {}

        if self.training:
            output_flo = output_dict['flow']
            output_occ = output_dict['occ']

            # div_flow trick
            target_flo_f = self._args.model_div_flow * target_dict["target1"]
            target_occ_f = target_dict["target_occ1"]

            # bchw
            flow_loss = 0
            occ_loss = 0

            for ii, output_ii in enumerate(output_flo):
                loss_ii = 0
                # Entries appear to alternate directions: even index = forward
                # (supervised), odd index = backward (detached — no GT).
                for jj in range(0, len(output_ii) // 2):
                    loss_ii = loss_ii + _elementwise_robust_epe_char(output_ii[2 * jj], _downsample2d_as(target_flo_f, output_ii[2 * jj])).sum()
                    # In-place detach cuts the gradient path of the backward
                    # prediction (mutates the model's output list).
                    output_ii[2 * jj + 1] = output_ii[2 * jj + 1].detach()
                # "* 2" compensates for dividing by the full list length while
                # only the forward half contributes to the sum.
                flow_loss = flow_loss + self._weights[ii] * loss_ii / len(output_ii) * 2

            for ii, output_ii in enumerate(output_occ):
                loss_ii = 0
                for jj in range(0, len(output_ii) // 2):
                    output_occ_f = self.occ_activ(output_ii[2 * jj])
                    output_ii[2 * jj + 1] = output_ii[2 * jj + 1].detach()
                    loss_ii = loss_ii + self.occ_loss_bce(output_occ_f, _downsample2d_as(target_occ_f, output_occ_f))
                occ_loss = occ_loss + self._weights[ii] * loss_ii / len(output_ii) * 2

            # Rescale the smaller loss to the larger one's magnitude; the
            # scale factors are detached, so no gradient flows through them.
            f_loss = flow_loss.detach()
            o_loss = occ_loss.detach()
            if f_loss > o_loss:
                f_l_w = 1
                o_l_w = f_loss / o_loss
            else:
                f_l_w = o_loss / f_loss
                o_l_w = 1

            loss_dict["flow_loss"] = flow_loss / self._batch_size
            loss_dict["occ_loss"] = occ_loss / self._batch_size
            loss_dict["total_loss"] = (flow_loss * f_l_w + occ_loss * o_l_w) / self._batch_size
        else:
            loss_dict["epe"] = _elementwise_epe(output_dict["flow"], target_dict["target1"]).mean()
            loss_dict["F1"] = f1_score(target_dict["target_occ1"], torch.round(self.occ_activ(output_dict["occ"])))

        return loss_dict
class MultiScaleEPE_PWC_Bi_Occ_upsample_KITTI(nn.Module):
    """Multi-scale loss for fine-tuning the bidirectional PWC model on KITTI.

    KITTI ground truth is sparse (per-pixel validity mask) and forward-only:
    invalid pixels and all backward/occlusion predictions are detached so
    they receive no gradient. Evaluation reports mean EPE and the KITTI
    outlier rate over valid pixels.
    """

    def __init__(self,
                 args):

        super(MultiScaleEPE_PWC_Bi_Occ_upsample_KITTI, self).__init__()
        self._args = args
        self._batch_size = args.batch_size
        # Small per-level weights — this is a fine-tuning schedule.
        self._weights = [0.001, 0.001, 0.001, 0.002, 0.004, 0.004, 0.004]
        self.occ_activ = nn.Sigmoid()

    def forward(self, output_dict, target_dict):
        loss_dict = {}

        # Validity mask of the sparse KITTI GT (assumed broadcastable to the
        # EPE map, presumably (b,1,h,w) — confirm against the data loader).
        valid_mask = target_dict["input_valid"]
        b, _, h, w = target_dict["target1"].size()

        if self.training:
            output_flo = output_dict['flow']
            output_occ = output_dict['occ']

            # div_flow trick
            target_flo_f = self._args.model_div_flow * target_dict["target1"]

            # bchw
            flow_loss = 0

            for ii, output_ii in enumerate(output_flo):
                loss_ii = 0
                # Even entries: forward direction (supervised); odd entries:
                # backward direction (detached below — no backward GT).
                for jj in range(0, len(output_ii) // 2):
                    # Upsample the prediction to GT resolution, robust
                    # (Charbonnier) EPE masked to the valid pixels.
                    valid_epe = _elementwise_robust_epe_char(_upsample2d_as(output_ii[2 * jj], target_flo_f), target_flo_f) * valid_mask
                    for bb in range(0, b):
                        # Detach invalid pixels in place so they contribute no
                        # gradient (their values are zeroed by the mask anyway).
                        valid_epe[bb, ...][valid_mask[bb, ...] == 0] = valid_epe[bb, ...][valid_mask[bb, ...] == 0].detach()
                        # Renormalize as if the whole image were valid.
                        # NOTE(review): divides by the per-sample valid count —
                        # an all-invalid sample would divide by zero.
                        norm_const = h * w / (valid_mask[bb, ...].sum())
                        loss_ii = loss_ii + valid_epe[bb, ...][valid_mask[bb, ...] != 0].sum() * norm_const
                    output_ii[2 * jj + 1] = output_ii[2 * jj + 1].detach()
                # "* 2" compensates for dividing by the full list length while
                # only the forward half is supervised.
                flow_loss = flow_loss + self._weights[ii] * loss_ii / len(output_ii) * 2

            # No occlusion GT on KITTI: detach every occlusion output in place.
            for ii, output_ii in enumerate(output_occ):
                for jj in range(0, len(output_ii) // 2):
                    output_ii[2 * jj] = output_ii[2 * jj].detach()
                    output_ii[2 * jj + 1] = output_ii[2 * jj + 1].detach()

            loss_dict["flow_loss"] = flow_loss / self._batch_size
            loss_dict["total_loss"] = flow_loss / self._batch_size
        else:
            # GT magnitude for the relative outlier criterion (eps avoids /0).
            flow_gt_mag = torch.norm(target_dict["target1"], p=2, dim=1, keepdim=True) + 1e-8
            flow_epe = _elementwise_epe(output_dict["flow"], target_dict["target1"]) * valid_mask

            epe_per_image = (flow_epe.view(b, -1).sum(1)) / (valid_mask.view(b, -1).sum(1))
            loss_dict["epe"] = epe_per_image.mean()

            # KITTI outlier: EPE > 3 px AND > 5% of the GT flow magnitude.
            outlier_epe = (flow_epe > 3).float() * ((flow_epe / flow_gt_mag) > 0.05).float() * valid_mask
            outlier_per_image = (outlier_epe.view(b, -1).sum(1)) / (valid_mask.view(b, -1).sum(1))
            loss_dict["outlier"] = outlier_per_image.mean()

        return loss_dict
================================================
FILE: main.py
================================================
from __future__ import absolute_import, division, print_function

import logging
import os
import subprocess
import sys

import torch

import commandline
import configuration as config
import logger
import runtime
import tools
def main():
    """Entry point: parse arguments, build loaders/model/optimizer, and hand
    off to the runtime loop for training, validation and/or inference.

    Returns whatever runtime.exec_runtime returns; exits with status 1 if no
    dataset could be loaded.
    """
    # Change working directory so relative paths resolve against this file.
    os.chdir(os.path.dirname(os.path.realpath(__file__)))

    # Parse commandline arguments
    args = commandline.setup_logging_and_parse_arguments(blocktitle="Commandline Arguments")

    # Set random seed, possibly on Cuda
    config.configure_random_seed(args)

    # DataLoader
    train_loader, validation_loader, inference_loader = config.configure_data_loaders(args)
    success = any(loader is not None for loader in [train_loader, validation_loader, inference_loader])
    if not success:
        logging.info("No dataset could be loaded successfully. Please check dataset paths!")
        # sys.exit instead of quit(): quit() is injected by the interactive
        # `site` module and is not meant for scripts; exit non-zero so callers
        # can detect the failure.
        sys.exit(1)

    # Configure data augmentation
    training_augmentation, validation_augmentation = config.configure_runtime_augmentations(args)

    # Configure model and loss
    model_and_loss = config.configure_model_and_loss(args)

    # Resume from checkpoint if available
    checkpoint_saver, checkpoint_stats = config.configure_checkpoint_saver(args, model_and_loss)

    # Checkpoint and save directory
    with logger.LoggingBlock("Save Directory", emph=True):
        logging.info("Save directory: %s" % args.save)
        # exist_ok avoids the check-then-create race of the old
        # os.path.exists / os.makedirs pair.
        os.makedirs(args.save, exist_ok=True)

    # Configure optimizer
    optimizer = config.configure_optimizer(args, model_and_loss)

    # Configure learning rate
    lr_scheduler = config.configure_lr_scheduler(args, optimizer)

    # If this is just an evaluation: overwrite savers and epochs
    if args.evaluation:
        args.start_epoch = 1
        args.total_epochs = 1
        train_loader = None
        checkpoint_saver = None
        optimizer = None
        lr_scheduler = None

    # Cuda optimization
    if args.cuda:
        torch.backends.cudnn.benchmark = True

    # Kickoff training, validation and/or testing
    return runtime.exec_runtime(
        args,
        checkpoint_saver=checkpoint_saver,
        model_and_loss=model_and_loss,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        train_loader=train_loader,
        validation_loader=validation_loader,
        inference_loader=inference_loader,
        training_augmentation=training_augmentation,
        validation_augmentation=validation_augmentation)
# Script entry point: run training/evaluation when invoked directly.
if __name__ == "__main__":
    main()
================================================
FILE: models/IRR_FlowNet.py
================================================
from __future__ import absolute_import, division, print_function
import torch
import torch.nn as nn
from .flownet_modules import conv, deconv
from .flownet_modules import concatenate_as, upsample2d_as
from .flownet_modules import initialize_msra
from .flownet_modules import WarpingLayer
from .irr_modules import OccUpsampleNetwork, RefineFlow, RefineOcc
class FlowNetS(nn.Module):
    """FlowNet-S style encoder/decoder predicting flow AND occlusion logits.

    forward() consumes level-2 features of image 1 and level-3 features of
    both images (the second possibly pre-warped by the caller) and returns
    flow and occlusion predictions at decoder levels 2..6 (finest to
    coarsest).
    """

    def __init__(self, args):
        super(FlowNetS, self).__init__()

        def make_conv(in_planes, out_planes, kernel_size, stride):
            # "same"-style padding for odd kernels.
            pad = kernel_size // 2
            return conv(in_planes, out_planes, kernel_size=kernel_size,
                        stride=stride, pad=pad, nonlinear=True, bias=True)

        # Encoder: input is the 256-ch concat of the two level-3 feature maps.
        self._conv3_1 = make_conv( 256, 256, kernel_size=3, stride=1)
        self._conv4 = make_conv( 256, 512, kernel_size=3, stride=2)
        self._conv4_1 = make_conv( 512, 512, kernel_size=3, stride=1)
        self._conv5 = make_conv( 512, 512, kernel_size=3, stride=2)
        self._conv5_1 = make_conv( 512, 512, kernel_size=3, stride=1)
        self._conv6 = make_conv( 512, 1024, kernel_size=3, stride=2)
        self._conv6_1 = make_conv(1024, 1024, kernel_size=3, stride=1)

        def make_deconv(in_planes, out_planes):
            # 2x learned upsampling with nonlinearity.
            return deconv(in_planes, out_planes, kernel_size=4, stride=2, pad=1,
                          nonlinear=True, bias=False)

        # Flow decoder deconvs ("+ 2" inputs carry the upsampled 2-ch flow).
        self._deconv5 = make_deconv(1024 , 512)
        self._deconv4 = make_deconv(1024 + 2, 256)
        self._deconv3 = make_deconv( 768 + 2, 128)
        self._deconv2 = make_deconv( 384 + 2, 64)

        # Occlusion decoder deconvs ("+ 1" inputs carry the 1-ch occ logits).
        self._deconv_occ5 = make_deconv(1024 , 512)
        self._deconv_occ4 = make_deconv(1024 + 1, 256)
        self._deconv_occ3 = make_deconv( 768 + 1, 128)
        self._deconv_occ2 = make_deconv( 384 + 1, 64)

        def make_predict(in_planes, out_planes):
            # Linear 3x3 prediction head (no nonlinearity).
            return conv(in_planes, out_planes, kernel_size=3, stride=1, pad=1,
                        nonlinear=False, bias=True)

        self._predict_flow6 = make_predict(1024 , 2)
        self._predict_flow5 = make_predict(1024 + 2, 2)
        self._predict_flow4 = make_predict( 768 + 2, 2)
        self._predict_flow3 = make_predict( 384 + 2, 2)
        self._predict_flow2 = make_predict( 128 + 2, 2)

        self._predict_occ6 = make_predict(1024 , 1)
        self._predict_occ5 = make_predict(1024 + 1, 1)
        self._predict_occ4 = make_predict( 768 + 1, 1)
        self._predict_occ3 = make_predict( 384 + 1, 1)
        self._predict_occ2 = make_predict( 128 + 1, 1)

        def make_upsample(in_planes, out_planes):
            # Linear 2x learned upsampling for predictions.
            return deconv(in_planes, out_planes, kernel_size=4, stride=2, pad=1,
                          nonlinear=False, bias=False)

        self._upsample_flow6_to_5 = make_upsample(2, 2)
        self._upsample_flow5_to_4 = make_upsample(2, 2)
        self._upsample_flow4_to_3 = make_upsample(2, 2)
        self._upsample_flow3_to_2 = make_upsample(2, 2)

        self._upsample_occ6_to_5 = make_upsample(1, 1)
        self._upsample_occ5_to_4 = make_upsample(1, 1)
        self._upsample_occ4_to_3 = make_upsample(1, 1)
        self._upsample_occ3_to_2 = make_upsample(1, 1)

    def forward(self, conv2_im1, conv3_im1, conv3_im2):
        # Shared encoder over the concatenated level-3 features.
        conv_concat3 = torch.cat((conv3_im1, conv3_im2), dim=1)
        conv3_1 = self._conv3_1(conv_concat3)
        conv4_1 = self._conv4_1(self._conv4(conv3_1))
        conv5_1 = self._conv5_1(self._conv5(conv4_1))
        conv6_1 = self._conv6_1(self._conv6(conv5_1))

        # Flow Decoder: coarse-to-fine; each level concatenates the encoder
        # skip, the decoder deconv, and the upsampled coarser prediction.
        predict_flow6 = self._predict_flow6(conv6_1)
        upsampled_flow6_to_5 = self._upsample_flow6_to_5(predict_flow6)
        deconv5 = self._deconv5(conv6_1)
        concat5 = concatenate_as((conv5_1, deconv5, upsampled_flow6_to_5), conv5_1, dim=1)
        predict_flow5 = self._predict_flow5(concat5)
        upsampled_flow5_to_4 = self._upsample_flow5_to_4(predict_flow5)
        deconv4 = self._deconv4(concat5)
        concat4 = concatenate_as((conv4_1, deconv4, upsampled_flow5_to_4), conv4_1, dim=1)
        predict_flow4 = self._predict_flow4(concat4)
        upsampled_flow4_to_3 = self._upsample_flow4_to_3(predict_flow4)
        deconv3 = self._deconv3(concat4)
        concat3 = concatenate_as((conv3_1, deconv3, upsampled_flow4_to_3), conv3_1, dim=1)
        predict_flow3 = self._predict_flow3(concat3)
        upsampled_flow3_to_2 = self._upsample_flow3_to_2(predict_flow3)
        deconv2 = self._deconv2(concat3)
        concat2 = concatenate_as((conv2_im1, deconv2, upsampled_flow3_to_2), conv2_im1, dim=1)
        predict_flow2 = self._predict_flow2(concat2)

        # Occ Decoder: identical topology with its own weights and 1-ch preds.
        predict_occ6 = self._predict_occ6(conv6_1)
        upsampled_occ6_to_5 = self._upsample_occ6_to_5(predict_occ6)
        deconv_occ5 = self._deconv_occ5(conv6_1)
        concat_occ5 = concatenate_as((conv5_1, deconv_occ5, upsampled_occ6_to_5), conv5_1, dim=1)
        predict_occ5 = self._predict_occ5(concat_occ5)
        upsampled_occ5_to_4 = self._upsample_occ5_to_4(predict_occ5)
        deconv_occ4 = self._deconv_occ4(concat_occ5)
        concat_occ4 = concatenate_as((conv4_1, deconv_occ4, upsampled_occ5_to_4), conv4_1, dim=1)
        predict_occ4 = self._predict_occ4(concat_occ4)
        upsampled_occ4_to_3 = self._upsample_occ4_to_3(predict_occ4)
        deconv_occ3 = self._deconv_occ3(concat_occ4)
        concat_occ3 = concatenate_as((conv3_1, deconv_occ3, upsampled_occ4_to_3), conv3_1, dim=1)
        predict_occ3 = self._predict_occ3(concat_occ3)
        upsampled_occ3_to_2 = self._upsample_occ3_to_2(predict_occ3)
        deconv_occ2 = self._deconv_occ2(concat_occ3)
        concat_occ2 = concatenate_as((conv2_im1, deconv_occ2, upsampled_occ3_to_2), conv2_im1, dim=1)
        predict_occ2 = self._predict_occ2(concat_occ2)

        return predict_flow2, predict_flow3, predict_flow4, predict_flow5, predict_flow6, predict_occ2, predict_occ3, predict_occ4, predict_occ5, predict_occ6
class FlowNet1S(nn.Module):
    """Iterative-residual bidirectional FlowNet (IRR variant).

    A single shared FlowNetS is applied num_iters times in both directions;
    each iteration warps the other image's level-3 features with the current
    flow, accumulates residual predictions, and refines flow/occlusion with
    small refinement networks.

    Training returns per-iteration pyramids of [forward, backward] pairs;
    evaluation returns only the final full-resolution forward flow (rescaled
    by 1/div_flow) and occlusion map.
    """

    def __init__(self, args, div_flow=0.05):
        super(FlowNet1S, self).__init__()
        self._flownets = FlowNetS(args)
        self._warping_layer = WarpingLayer()
        # Internal flow scale factor (predictions are div_flow * true flow).
        self._div_flow = div_flow
        self._num_iters = args.num_iters

        def make_conv(in_planes, out_planes, kernel_size, stride):
            pad = kernel_size // 2
            return conv(in_planes, out_planes, kernel_size=kernel_size,
                        stride=stride, pad=pad, nonlinear=True, bias=True)

        # Shared 3-level feature extractor applied to both images.
        self._conv1 = make_conv( 3, 32, kernel_size=7, stride=2)
        self._conv2 = make_conv( 32, 64, kernel_size=5, stride=2)
        self._conv3 = make_conv( 64, 128, kernel_size=5, stride=2)

        self.occ_shuffle_upsample = OccUpsampleNetwork(11, 1)
        self.refine_flow = RefineFlow(2 + 1 + 64)
        self.refine_occ = RefineOcc(1 + 64 + 64)

        initialize_msra(self.modules())

    def forward(self, input_dict):
        im1 = input_dict['input1']
        im2 = input_dict['input2']

        # Feature pyramids for both images (levels 1..3).
        conv1_im1 = self._conv1(im1)
        conv2_im1 = self._conv2(conv1_im1)
        conv3_im1 = self._conv3(conv2_im1)
        # "_wp" tensors are replaced by warped features after each iteration.
        conv3_im1_wp = conv3_im1

        conv1_im2 = self._conv1(im2)
        conv2_im2 = self._conv2(conv1_im2)
        conv3_im2 = self._conv3(conv2_im2)
        conv3_im2_wp = conv3_im2

        # Each list collects one [forward, backward] pair per iteration.
        out_dict = {}
        out_dict['flow'] = []
        out_dict['flow1'] = []
        out_dict['flow2'] = []
        out_dict['flow3'] = []
        out_dict['flow4'] = []
        out_dict['flow5'] = []
        out_dict['flow6'] = []
        out_dict['occ'] = []
        out_dict['occ1'] = []
        out_dict['occ2'] = []
        out_dict['occ3'] = []
        out_dict['occ4'] = []
        out_dict['occ5'] = []
        out_dict['occ6'] = []

        # warping:
        _, _, height_im, width_im = im1.size()

        # for iterative
        for ii in range(0, self._num_iters):
            # One shared-weight FlowNetS pass per direction; the second
            # image's features are the (possibly) warped ones.
            flo2_f, flo3_f, flo4_f, flo5_f, flo6_f, occ2_f, occ3_f, occ4_f, occ5_f, occ6_f = self._flownets(conv2_im1,
                                                                                                            conv3_im1,
                                                                                                            conv3_im2_wp)
            flo2_b, flo3_b, flo4_b, flo5_b, flo6_b, occ2_b, occ3_b, occ4_b, occ5_b, occ6_b = self._flownets(conv2_im2,
                                                                                                            conv3_im2,
                                                                                                            conv3_im1_wp)

            if ii == 0:
                # First iteration: store raw predictions.
                out_dict['flow2'].append([flo2_f, flo2_b])
                out_dict['flow3'].append([flo3_f, flo3_b])
                out_dict['flow4'].append([flo4_f, flo4_b])
                out_dict['flow5'].append([flo5_f, flo5_b])
                out_dict['flow6'].append([flo6_f, flo6_b])
                out_dict['occ2'].append([occ2_f, occ2_b])
                out_dict['occ3'].append([occ3_f, occ3_b])
                out_dict['occ4'].append([occ4_f, occ4_b])
                out_dict['occ5'].append([occ5_f, occ5_b])
                out_dict['occ6'].append([occ6_f, occ6_b])

                flo2_f_out = flo2_f
                flo2_b_out = flo2_b
                occ2_f_out = occ2_f
                occ2_b_out = occ2_b
            else:
                # Later iterations: accumulate residuals on top of the
                # previous iteration's estimates.
                out_dict['flow2'].append([flo2_f + out_dict['flow2'][ii - 1][0], flo2_b + out_dict['flow2'][ii - 1][1]])
                out_dict['flow3'].append([flo3_f + out_dict['flow3'][ii - 1][0], flo3_b + out_dict['flow3'][ii - 1][1]])
                out_dict['flow4'].append([flo4_f + out_dict['flow4'][ii - 1][0], flo4_b + out_dict['flow4'][ii - 1][1]])
                out_dict['flow5'].append([flo5_f + out_dict['flow5'][ii - 1][0], flo5_b + out_dict['flow5'][ii - 1][1]])
                out_dict['flow6'].append([flo6_f + out_dict['flow6'][ii - 1][0], flo6_b + out_dict['flow6'][ii - 1][1]])
                out_dict['occ2'].append([occ2_f + out_dict['occ2'][ii - 1][0], occ2_b + out_dict['occ2'][ii - 1][1]])
                out_dict['occ3'].append([occ3_f + out_dict['occ3'][ii - 1][0], occ3_b + out_dict['occ3'][ii - 1][1]])
                out_dict['occ4'].append([occ4_f + out_dict['occ4'][ii - 1][0], occ4_b + out_dict['occ4'][ii - 1][1]])
                out_dict['occ5'].append([occ5_f + out_dict['occ5'][ii - 1][0], occ5_b + out_dict['occ5'][ii - 1][1]])
                out_dict['occ6'].append([occ6_f + out_dict['occ6'][ii - 1][0], occ6_b + out_dict['occ6'][ii - 1][1]])

                # Residual on top of the previous refined (level-1) estimate.
                flo2_f_out = flo2_f + upsample2d_as(out_dict['flow1'][ii - 1][0], flo2_f, mode="bilinear")
                flo2_b_out = flo2_b + upsample2d_as(out_dict['flow1'][ii - 1][1], flo2_b, mode="bilinear")
                occ2_f_out = occ2_f + upsample2d_as(out_dict['occ1'][ii - 1][0], occ2_f, mode="bilinear")
                occ2_b_out = occ2_b + upsample2d_as(out_dict['occ1'][ii - 1][1], occ2_b, mode="bilinear")

            ## refine layer
            # Bring level-2 estimates to the level-2 feature resolution.
            flo2_f_out = upsample2d_as(flo2_f_out, conv2_im1, mode="bilinear")
            flo2_b_out = upsample2d_as(flo2_b_out, conv2_im2, mode="bilinear")
            occ2_f_out = upsample2d_as(occ2_f_out, conv2_im1, mode="bilinear")
            occ2_b_out = upsample2d_as(occ2_b_out, conv2_im2, mode="bilinear")

            img1_resize = upsample2d_as(im1, flo2_f_out, mode="bilinear")
            img2_resize = upsample2d_as(im2, flo2_b_out, mode="bilinear")
            img2_warp = self._warping_layer(img2_resize, flo2_f_out, height_im, width_im, self._div_flow)
            img1_warp = self._warping_layer(img1_resize, flo2_b_out, height_im, width_im, self._div_flow)

            # flow refine: conditioned on the brightness error image.
            # (Estimates are detached so the refiner trains on its own.)
            flow_f = self.refine_flow(flo2_f_out.detach(), img1_resize - img2_warp, conv2_im1)
            flow_b = self.refine_flow(flo2_b_out.detach(), img2_resize - img1_warp, conv2_im2)

            # occ refine: conditioned on the feature mismatch after warping.
            conv2_im2_warp = self._warping_layer(conv2_im2, flow_f, height_im, width_im, self._div_flow)
            conv2_im1_warp = self._warping_layer(conv2_im1, flow_b, height_im, width_im, self._div_flow)
            occ_f = self.refine_occ(occ2_f_out.detach(), conv2_im1, conv2_im1 - conv2_im2_warp)
            occ_b = self.refine_occ(occ2_b_out.detach(), conv2_im2, conv2_im2 - conv2_im1_warp)

            out_dict['flow1'].append([flow_f, flow_b])
            out_dict['occ1'].append([occ_f, occ_b])

            ## upsample layer
            flow_f = upsample2d_as(flow_f, im1, mode="bilinear")
            flow_b = upsample2d_as(flow_b, im2, mode="bilinear")
            out_dict['flow'].append([flow_f, flow_b])

            im2_warp = self._warping_layer(im2, flow_f, height_im, width_im, self._div_flow)
            im1_warp = self._warping_layer(im1, flow_b, height_im, width_im, self._div_flow)
            flow_b_warp = self._warping_layer(flow_b, flow_f, height_im, width_im, self._div_flow)
            flow_f_warp = self._warping_layer(flow_f, flow_b, height_im, width_im, self._div_flow)

            # Occlusion upsampling guided by images, warped images and flows.
            occ_f = self.occ_shuffle_upsample(occ_f, torch.cat([im1, im2_warp, flow_f, flow_b_warp], dim=1))
            occ_b = self.occ_shuffle_upsample(occ_b, torch.cat([im2, im1_warp, flow_b, flow_f_warp], dim=1))
            out_dict['occ'].append([occ_f, occ_b])

            if ii < (self._num_iters - 1):
                # Warp each image's level-3 features with the latest flow for
                # the next iteration's input.
                flow_f_resized = upsample2d_as(flow_f, conv3_im2, mode="bilinear")
                flow_b_resized = upsample2d_as(flow_b, conv3_im1, mode="bilinear")
                conv3_im2_wp = self._warping_layer(conv3_im2, flow_f_resized, height_im, width_im, self._div_flow)
                conv3_im1_wp = self._warping_layer(conv3_im1, flow_b_resized, height_im, width_im, self._div_flow)

        if self.training:
            return out_dict
        else:
            # Evaluation: final forward outputs only, flow unscaled to pixels.
            out_dict_eval = {}
            out_dict_eval['flow'] = upsample2d_as(out_dict['flow'][self._num_iters - 1][0], im1, mode="bilinear") / self._div_flow
            out_dict_eval['occ'] = upsample2d_as(out_dict['occ'][self._num_iters - 1][0], im1, mode="bilinear")
            return out_dict_eval
================================================
FILE: models/IRR_PWC.py
================================================
from __future__ import absolute_import, division, print_function
import torch
import torch.nn as nn
from .pwc_modules import conv, upsample2d_as, rescale_flow, initialize_msra, compute_cost_volume
from .pwc_modules import WarpingLayer, FeatureExtractor, ContextNetwork, FlowEstimatorDense, OccContextNetwork, OccEstimatorDense
from .irr_modules import OccUpsampleNetwork, RefineFlow, RefineOcc
import copy
class PWCNet(nn.Module):
    """IRR-PWC: PWC-Net with iterative residual refinement.

    Jointly estimates bidirectional optical flow (forward and backward) and
    occlusion maps over a feature pyramid.  A single flow estimator, context
    network, occlusion estimator and refinement module are shared across all
    pyramid levels (the IRR weight-sharing scheme).
    """

    def __init__(self, args, div_flow=0.05):
        super(PWCNet, self).__init__()
        self.args = args
        self._div_flow = div_flow  # global flow scaling factor (flows are stored scaled by this)
        self.search_range = 4      # correlation search radius, per side
        self.num_chs = [3, 16, 32, 64, 96, 128, 196]  # feature channels per pyramid level
        self.output_level = 4      # last pyramid level at which flow/occ are estimated
        self.num_levels = 7
        self.leakyRELU = nn.LeakyReLU(0.1, inplace=True)

        self.feature_pyramid_extractor = FeatureExtractor(self.num_chs)
        self.warping_layer = WarpingLayer()

        # Cost-volume channels: one per displacement in a (2r+1)^2 window.
        self.dim_corr = (self.search_range * 2 + 1) ** 2
        self.num_ch_in_flo = self.dim_corr + 32 + 2  # corr + 32-ch feature + 2-ch flow
        self.num_ch_in_occ = self.dim_corr + 32 + 1  # corr + 32-ch feature + 1-ch occ

        self.flow_estimators = FlowEstimatorDense(self.num_ch_in_flo)
        self.context_networks = ContextNetwork(self.num_ch_in_flo + 448 + 2)
        self.occ_estimators = OccEstimatorDense(self.num_ch_in_occ)
        self.occ_context_networks = OccContextNetwork(self.num_ch_in_occ + 448 + 1)
        self.occ_shuffle_upsample = OccUpsampleNetwork(11, 1)

        # 1x1 convs mapping each level's feature channels down to 32 so the
        # shared estimators can be reused at every level (indexed by level).
        self.conv_1x1 = nn.ModuleList([conv(196, 32, kernel_size=1, stride=1, dilation=1),
                                       conv(128, 32, kernel_size=1, stride=1, dilation=1),
                                       conv(96, 32, kernel_size=1, stride=1, dilation=1),
                                       conv(64, 32, kernel_size=1, stride=1, dilation=1)])
        # Used in the upsampling stage to map 16-ch features to 3 channels
        # (matching the raw-image channel count expected by occ_shuffle_upsample).
        self.conv_1x1_1 = conv(16, 3, kernel_size=1, stride=1, dilation=1)

        self.refine_flow = RefineFlow(2 + 1 + 32)
        self.refine_occ = RefineOcc(1 + 32 + 32)

        self.corr_params = {"pad_size": self.search_range, "kernel_size": 1, "max_disp": self.search_range, "stride1": 1, "stride2": 1, "corr_multiply": 1}

        initialize_msra(self.modules())

    def forward(self, input_dict):
        """Run bidirectional flow and occlusion estimation.

        Args:
            input_dict: dict with 'input1' and 'input2' NCHW image batches.

        Returns:
            Training: dict with per-level lists under 'flow' and 'occ'.
            Eval: dict with the final forward flow (rescaled by 1/div_flow)
            and occlusion map, both upsampled to input resolution.
        """
        x1_raw = input_dict['input1']
        x2_raw = input_dict['input2']
        batch_size, _, height_im, width_im = x1_raw.size()

        # on the bottom level are original images
        x1_pyramid = self.feature_pyramid_extractor(x1_raw) + [x1_raw]
        x2_pyramid = self.feature_pyramid_extractor(x2_raw) + [x2_raw]

        # outputs
        output_dict = {}
        output_dict_eval = {}
        flows = []
        occs = []

        _, _, h_x1, w_x1, = x1_pyramid[0].size()
        # Zero-initialized estimates at the coarsest level.
        # NOTE(review): hard-coded .cuda() assumes GPU inputs — confirm.
        flow_f = torch.zeros(batch_size, 2, h_x1, w_x1).float().cuda()
        flow_b = torch.zeros(batch_size, 2, h_x1, w_x1).float().cuda()
        occ_f = torch.zeros(batch_size, 1, h_x1, w_x1).float().cuda()
        occ_b = torch.zeros(batch_size, 1, h_x1, w_x1).float().cuda()

        for l, (x1, x2) in enumerate(zip(x1_pyramid, x2_pyramid)):

            if l <= self.output_level:
                # warping
                if l == 0:
                    # Coarsest level: flow is still zero, warping is a no-op.
                    x2_warp = x2
                    x1_warp = x1
                else:
                    # Upsample previous estimates to this level and warp the
                    # other frame's features with them.
                    flow_f = upsample2d_as(flow_f, x1, mode="bilinear")
                    flow_b = upsample2d_as(flow_b, x2, mode="bilinear")
                    occ_f = upsample2d_as(occ_f, x1, mode="bilinear")
                    occ_b = upsample2d_as(occ_b, x2, mode="bilinear")
                    x2_warp = self.warping_layer(x2, flow_f, height_im, width_im, self._div_flow)
                    x1_warp = self.warping_layer(x1, flow_b, height_im, width_im, self._div_flow)

                # correlation: cost volumes in both directions
                out_corr_f = compute_cost_volume(x1, x2_warp, self.corr_params)
                out_corr_b = compute_cost_volume(x2, x1_warp, self.corr_params)
                out_corr_relu_f = self.leakyRELU(out_corr_f)
                out_corr_relu_b = self.leakyRELU(out_corr_b)

                # Reduce features to 32 channels for the shared estimators;
                # at the output level the features are used as-is.
                if l != self.output_level:
                    x1_1by1 = self.conv_1x1[l](x1)
                    x2_1by1 = self.conv_1x1[l](x2)
                else:
                    x1_1by1 = x1
                    x2_1by1 = x2

                # concat and estimate flow (switch to local, per-level coords)
                flow_f = rescale_flow(flow_f, self._div_flow, width_im, height_im, to_local=True)
                flow_b = rescale_flow(flow_b, self._div_flow, width_im, height_im, to_local=True)
                x_intm_f, flow_res_f = self.flow_estimators(torch.cat([out_corr_relu_f, x1_1by1, flow_f], dim=1))
                x_intm_b, flow_res_b = self.flow_estimators(torch.cat([out_corr_relu_b, x2_1by1, flow_b], dim=1))
                # Residual update, then a second residual from the context network.
                flow_est_f = flow_f + flow_res_f
                flow_est_b = flow_b + flow_res_b
                flow_cont_f = flow_est_f + self.context_networks(torch.cat([x_intm_f, flow_est_f], dim=1))
                flow_cont_b = flow_est_b + self.context_networks(torch.cat([x_intm_b, flow_est_b], dim=1))

                # occ estimation (same residual scheme as the flow branch)
                x_intm_occ_f, occ_res_f = self.occ_estimators(torch.cat([out_corr_relu_f, x1_1by1, occ_f], dim=1))
                x_intm_occ_b, occ_res_b = self.occ_estimators(torch.cat([out_corr_relu_b, x2_1by1, occ_b], dim=1))
                occ_est_f = occ_f + occ_res_f
                occ_est_b = occ_b + occ_res_b
                occ_cont_f = occ_est_f + self.occ_context_networks(torch.cat([x_intm_occ_f, occ_est_f], dim=1))
                occ_cont_b = occ_est_b + self.occ_context_networks(torch.cat([x_intm_occ_b, occ_est_b], dim=1))

                # refinement: photometric difference between each frame and
                # the other frame warped by the current flow
                img1_resize = upsample2d_as(x1_raw, flow_f, mode="bilinear")
                img2_resize = upsample2d_as(x2_raw, flow_b, mode="bilinear")
                img2_warp = self.warping_layer(img2_resize, rescale_flow(flow_cont_f, self._div_flow, width_im, height_im, to_local=False), height_im, width_im, self._div_flow)
                img1_warp = self.warping_layer(img1_resize, rescale_flow(flow_cont_b, self._div_flow, width_im, height_im, to_local=False), height_im, width_im, self._div_flow)

                # flow refine; .detach() stops the refinement gradient from
                # flowing back into the estimator/context outputs
                flow_f = self.refine_flow(flow_cont_f.detach(), img1_resize - img2_warp, x1_1by1)
                flow_b = self.refine_flow(flow_cont_b.detach(), img2_resize - img1_warp, x2_1by1)

                # back to the global flow scaling
                flow_cont_f = rescale_flow(flow_cont_f, self._div_flow, width_im, height_im, to_local=False)
                flow_cont_b = rescale_flow(flow_cont_b, self._div_flow, width_im, height_im, to_local=False)
                flow_f = rescale_flow(flow_f, self._div_flow, width_im, height_im, to_local=False)
                flow_b = rescale_flow(flow_b, self._div_flow, width_im, height_im, to_local=False)

                # occ refine: feature-space mismatch with the warped other frame
                x2_1by1_warp = self.warping_layer(x2_1by1, flow_f, height_im, width_im, self._div_flow)
                x1_1by1_warp = self.warping_layer(x1_1by1, flow_b, height_im, width_im, self._div_flow)
                occ_f = self.refine_occ(occ_cont_f.detach(), x1_1by1, x1_1by1 - x2_1by1_warp)
                occ_b = self.refine_occ(occ_cont_b.detach(), x2_1by1, x2_1by1 - x1_1by1_warp)

                flows.append([flow_cont_f, flow_cont_b, flow_f, flow_b])
                occs.append([occ_cont_f, occ_cont_b, occ_f, occ_b])

            else:
                # Above the output level: no more estimation — only bilinear
                # flow upsampling plus learned occlusion upsampling.
                flow_f = upsample2d_as(flow_f, x1, mode="bilinear")
                flow_b = upsample2d_as(flow_b, x2, mode="bilinear")
                flows.append([flow_f, flow_b])

                x2_warp = self.warping_layer(x2, flow_f, height_im, width_im, self._div_flow)
                x1_warp = self.warping_layer(x1, flow_b, height_im, width_im, self._div_flow)
                flow_b_warp = self.warping_layer(flow_b, flow_f, height_im, width_im, self._div_flow)
                flow_f_warp = self.warping_layer(flow_f, flow_b, height_im, width_im, self._div_flow)

                if l != self.num_levels-1:
                    x1_in = self.conv_1x1_1(x1)
                    x2_in = self.conv_1x1_1(x2)
                    x1_w_in = self.conv_1x1_1(x1_warp)
                    x2_w_in = self.conv_1x1_1(x2_warp)
                else:
                    # Last level: the "features" are the raw 3-channel images
                    # (appended to the pyramids above), so no projection.
                    x1_in = x1
                    x2_in = x2
                    x1_w_in = x1_warp
                    x2_w_in = x2_warp

                occ_f = self.occ_shuffle_upsample(occ_f, torch.cat([x1_in, x2_w_in, flow_f, flow_b_warp], dim=1))
                occ_b = self.occ_shuffle_upsample(occ_b, torch.cat([x2_in, x1_w_in, flow_b, flow_f_warp], dim=1))

                occs.append([occ_f, occ_b])

        # Final forward-direction predictions at input resolution; flow is
        # converted back to pixel units via 1/div_flow.
        output_dict_eval['flow'] = upsample2d_as(flow_f, x1_raw, mode="bilinear") * (1.0 / self._div_flow)
        output_dict_eval['occ'] = upsample2d_as(occ_f, x1_raw, mode="bilinear")
        output_dict['flow'] = flows
        output_dict['occ'] = occs

        if self.training:
            return output_dict
        else:
            return output_dict_eval
================================================
FILE: models/__init__.py
================================================
# Model registry: import every network module and expose each network class
# under a unique, human-readable name so callers can look models up by name.
from . import flownet1s
from . import flownet1s_irr
from . import flownet1s_irr_bi
from . import flownet1s_irr_occ
from . import flownet1s_irr_occ_bi
from . import IRR_FlowNet
from . import pwcnet
from . import pwcnet_bi
from . import pwcnet_occ
from . import pwcnet_occ_bi
from . import pwcnet_irr
from . import pwcnet_irr_bi
from . import pwcnet_irr_occ
from . import pwcnet_irr_occ_bi
from . import IRR_PWC

FlowNet1S = flownet1s.FlowNet1S
FlowNet1S_irr = flownet1s_irr.FlowNet1S
FlowNet1S_irr_bi = flownet1s_irr_bi.FlowNet1S
FlowNet1S_irr_occ = flownet1s_irr_occ.FlowNet1S
FlowNet1S_irr_occ_bi = flownet1s_irr_occ_bi.FlowNet1S
PWCNet = pwcnet.PWCNet
PWCNet_bi = pwcnet_bi.PWCNet
PWCNet_occ = pwcnet_occ.PWCNet
PWCNet_occ_bi = pwcnet_occ_bi.PWCNet
PWCNet_irr = pwcnet_irr.PWCNet
PWCNet_irr_bi = pwcnet_irr_bi.PWCNet
PWCNet_irr_occ = pwcnet_irr_occ.PWCNet
PWCNet_irr_occ_bi = pwcnet_irr_occ_bi.PWCNet
# NOTE: these two rebind the submodule names imported above; after this point
# `models.IRR_FlowNet` / `models.IRR_PWC` refer to the classes, not the modules.
IRR_FlowNet = IRR_FlowNet.FlowNet1S
IRR_PWC = IRR_PWC.PWCNet
================================================
FILE: models/correlation_package/__init__.py
================================================
================================================
FILE: models/correlation_package/correlation.py
================================================
import torch
from torch.nn.modules.module import Module
from torch.autograd import Function
import correlation_cuda
class CorrelationFunction(Function):
    """Autograd wrapper around the compiled `correlation_cuda` extension.

    NOTE(review): this uses the legacy, instantiation-based
    torch.autograd.Function API (state on `self`, non-static
    forward/backward).  That style is only supported by older PyTorch
    versions — confirm against the torch version this repo pins.
    """

    def __init__(self, pad_size=3, kernel_size=3, max_displacement=20, stride1=1, stride2=2, corr_multiply=1):
        super(CorrelationFunction, self).__init__()
        self.pad_size = pad_size
        self.kernel_size = kernel_size
        self.max_displacement = max_displacement
        self.stride1 = stride1
        self.stride2 = stride2
        self.corr_multiply = corr_multiply
        # self.out_channel = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1)

    def forward(self, input1, input2):
        # Keep the inputs for the backward pass.
        self.save_for_backward(input1, input2)

        with torch.cuda.device_of(input1):
            # Empty tensors on the inputs' device/dtype; the extension
            # resizes them (scratch buffers rbot1/rbot2 and the output).
            rbot1 = input1.new()
            rbot2 = input2.new()
            output = input1.new()

            correlation_cuda.forward(input1, input2, rbot1, rbot2, output,
                self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply)

        return output

    def backward(self, grad_output):
        input1, input2 = self.saved_tensors

        with torch.cuda.device_of(input1):
            rbot1 = input1.new()
            rbot2 = input2.new()
            grad_input1 = input1.new()
            grad_input2 = input2.new()

            correlation_cuda.backward(input1, input2, rbot1, rbot2, grad_output, grad_input1, grad_input2,
                self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply)

        return grad_input1, grad_input2
class Correlation(Module):
    """nn.Module interface to the CUDA correlation layer.

    Stores the correlation hyper-parameters and forwards them to a fresh
    CorrelationFunction instance on every call.
    """

    def __init__(self, pad_size=0, kernel_size=0, max_displacement=0, stride1=1, stride2=2, corr_multiply=1):
        super(Correlation, self).__init__()
        self.pad_size = pad_size
        self.kernel_size = kernel_size
        self.max_displacement = max_displacement
        self.stride1 = stride1
        self.stride2 = stride2
        self.corr_multiply = corr_multiply

    def forward(self, input1, input2):
        # Legacy autograd style: build a function object per call, then apply it.
        corr_fn = CorrelationFunction(self.pad_size,
                                      self.kernel_size,
                                      self.max_displacement,
                                      self.stride1,
                                      self.stride2,
                                      self.corr_multiply)
        return corr_fn(input1, input2)
================================================
FILE: models/correlation_package/correlation_cuda.cc
================================================
#include
#include
#include
#include
#include "correlation_cuda_kernel.cuh"
// Host-side entry point for the correlation forward pass.
// Resizes the padded channels-last scratch buffers (rInput1/rInput2) and the
// output tensor, zero-fills them, and launches the CUDA kernels.
// Returns 1 on success; raises via AT_ERROR on kernel failure.
int correlation_forward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& output,
                             int pad_size,
                             int kernel_size,
                             int max_displacement,
                             int stride1,
                             int stride2,
                             int corr_type_multiply)
{
  int batchSize = input1.size(0);
  int nInputChannels = input1.size(1);
  int inputHeight = input1.size(2);
  int inputWidth = input1.size(3);

  int kernel_radius = (kernel_size - 1) / 2;
  int border_radius = kernel_radius + max_displacement;

  int paddedInputHeight = inputHeight + 2 * pad_size;
  int paddedInputWidth = inputWidth + 2 * pad_size;

  // One output channel per displacement in the search window.
  int nOutputChannels = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1);

  // FIX: restored the <float> template arguments stripped from static_cast
  // (a bare `static_cast(...)` does not compile).
  int outputHeight = ceil(static_cast<float>(paddedInputHeight - 2 * border_radius) / static_cast<float>(stride1));
  int outputwidth = ceil(static_cast<float>(paddedInputWidth - 2 * border_radius) / static_cast<float>(stride1));

  // rInput* hold the inputs re-packed into padded NHWC layout.
  rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
  rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
  output.resize_({batchSize, nOutputChannels, outputHeight, outputwidth});

  rInput1.fill_(0);
  rInput2.fill_(0);
  output.fill_(0);

  int success = correlation_forward_cuda_kernel(
    output,
    output.size(0),
    output.size(1),
    output.size(2),
    output.size(3),
    output.stride(0),
    output.stride(1),
    output.stride(2),
    output.stride(3),
    input1,
    input1.size(1),
    input1.size(2),
    input1.size(3),
    input1.stride(0),
    input1.stride(1),
    input1.stride(2),
    input1.stride(3),
    input2,
    input2.size(1),
    input2.stride(0),
    input2.stride(1),
    input2.stride(2),
    input2.stride(3),
    rInput1,
    rInput2,
    pad_size,
    kernel_size,
    max_displacement,
    stride1,
    stride2,
    corr_type_multiply,
    at::globalContext().getCurrentCUDAStream()
  );

  //check for errors
  if (!success) {
    AT_ERROR("CUDA call failed");
  }

  return 1;
}
// Host-side entry point for the correlation backward pass.
// Resizes the padded channels-last scratch buffers and the two gradient
// tensors, zero-fills them, and launches the backward CUDA kernels.
// Returns 1 on success; raises via AT_ERROR on kernel failure.
int correlation_backward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& gradOutput,
                              at::Tensor& gradInput1, at::Tensor& gradInput2,
                              int pad_size,
                              int kernel_size,
                              int max_displacement,
                              int stride1,
                              int stride2,
                              int corr_type_multiply)
{
  int batchSize = input1.size(0);
  int nInputChannels = input1.size(1);
  int paddedInputHeight = input1.size(2)+ 2 * pad_size;
  int paddedInputWidth = input1.size(3)+ 2 * pad_size;

  int height = input1.size(2);
  int width = input1.size(3);

  // Scratch buffers in padded NHWC layout; gradients match the input shape.
  rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
  rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
  gradInput1.resize_({batchSize, nInputChannels, height, width});
  gradInput2.resize_({batchSize, nInputChannels, height, width});

  // The kernels assume zero-filled buffers (borders are never written).
  rInput1.fill_(0);
  rInput2.fill_(0);
  gradInput1.fill_(0);
  gradInput2.fill_(0);

  int success = correlation_backward_cuda_kernel(gradOutput,
    gradOutput.size(0),
    gradOutput.size(1),
    gradOutput.size(2),
    gradOutput.size(3),
    gradOutput.stride(0),
    gradOutput.stride(1),
    gradOutput.stride(2),
    gradOutput.stride(3),
    input1,
    input1.size(1),
    input1.size(2),
    input1.size(3),
    input1.stride(0),
    input1.stride(1),
    input1.stride(2),
    input1.stride(3),
    input2,
    input2.stride(0),
    input2.stride(1),
    input2.stride(2),
    input2.stride(3),
    gradInput1,
    gradInput1.stride(0),
    gradInput1.stride(1),
    gradInput1.stride(2),
    gradInput1.stride(3),
    gradInput2,
    gradInput2.size(1),
    gradInput2.stride(0),
    gradInput2.stride(1),
    gradInput2.stride(2),
    gradInput2.stride(3),
    rInput1,
    rInput2,
    pad_size,
    kernel_size,
    max_displacement,
    stride1,
    stride2,
    corr_type_multiply,
    at::globalContext().getCurrentCUDAStream()
  );

  if (!success) {
    AT_ERROR("CUDA call failed");
  }

  return 1;
}
// Python bindings: exposes the host entry points as
// `correlation_cuda.forward` / `correlation_cuda.backward`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &correlation_forward_cuda, "Correlation forward (CUDA)");
  m.def("backward", &correlation_backward_cuda, "Correlation backward (CUDA)");
}
================================================
FILE: models/correlation_package/correlation_cuda_kernel.cu
================================================
#include
#include "correlation_cuda_kernel.cuh"
#define CUDA_NUM_THREADS 1024
#define THREADS_PER_BLOCK 32
#include
#include
#include
#include
using at::Half;
// Copies one NCHW tensor into a zero-padded NHWC (channels-last) buffer.
// Launched with grid (batch, height, width); the block's threads stride
// cooperatively over the channel dimension.
// FIX: restored the `<typename scalar_t>` template parameter list that was
// stripped from the header (a bare `template` does not compile).
template <typename scalar_t>
__global__ void channels_first(const scalar_t* __restrict__ input, scalar_t* rinput, int channels, int height, int width, int pad_size)
{
    // n (batch size), c (num of channels), y (height), x (width)
    int n = blockIdx.x;
    int y = blockIdx.y;
    int x = blockIdx.z;
    int ch_off = threadIdx.x;
    scalar_t value;

    // Flat-index strides for the source (NCHW) layout.
    int dimcyx = channels * height * width;
    int dimyx = height * width;

    // Flat-index strides for the padded destination (NHWC) layout.
    int p_dimx = (width + 2 * pad_size);
    int p_dimy = (height + 2 * pad_size);
    int p_dimyxc = channels * p_dimy * p_dimx;
    int p_dimxc = p_dimx * channels;

    for (int c = ch_off; c < channels; c += THREADS_PER_BLOCK) {
        value = input[n * dimcyx + c * dimyx + y * width + x];
        rinput[n * p_dimyxc + (y + pad_size) * p_dimxc + (x + pad_size) * channels + c] = value;
    }
}
// Computes the cost volume between two padded NHWC feature maps.
// Launched with grid (batch, outputHeight, outputWidth); the block's threads
// reduce the channel dimension cooperatively through shared memory.
// FIX: restored the `<typename scalar_t>` template parameter list stripped
// from the header, and added the missing __syncthreads() after the
// reduction — without it, threads could reset prod_sum for the next
// displacement while thread 0 is still reading it.
template <typename scalar_t>
__global__ void correlation_forward(scalar_t* output, int nOutputChannels, int outputHeight, int outputWidth,
    const scalar_t* __restrict__ rInput1, int nInputChannels, int inputHeight, int inputWidth,
    const scalar_t* __restrict__ rInput2,
    int pad_size,
    int kernel_size,
    int max_displacement,
    int stride1,
    int stride2)
{
    // n (batch size), c (num of channels), y (height), x (width)
    int pInputWidth = inputWidth + 2 * pad_size;
    int pInputHeight = inputHeight + 2 * pad_size;

    int kernel_rad = (kernel_size - 1) / 2;
    int displacement_rad = max_displacement / stride2;
    int displacement_size = 2 * displacement_rad + 1;

    int n = blockIdx.x;
    int y1 = blockIdx.y * stride1 + max_displacement;
    int x1 = blockIdx.z * stride1 + max_displacement;
    int c = threadIdx.x;

    int pdimyxc = pInputHeight * pInputWidth * nInputChannels;
    int pdimxc = pInputWidth * nInputChannels;
    int pdimc = nInputChannels;

    int tdimcyx = nOutputChannels * outputHeight * outputWidth;
    int tdimyx = outputHeight * outputWidth;
    int tdimx = outputWidth;

    // Normalization constant: number of elements in one correlation patch.
    scalar_t nelems = kernel_size * kernel_size * pdimc;

    __shared__ scalar_t prod_sum[THREADS_PER_BLOCK];

    // no significant speed-up in using chip memory for input1 sub-data,
    // not enough chip memory size to accomodate memory per block for input2 sub-data
    // instead i've used device memory for both

    // element-wise product along channel axis
    for (int tj = -displacement_rad; tj <= displacement_rad; ++tj) {
        for (int ti = -displacement_rad; ti <= displacement_rad; ++ti) {
            prod_sum[c] = 0;
            int x2 = x1 + ti*stride2;
            int y2 = y1 + tj*stride2;

            for (int j = -kernel_rad; j <= kernel_rad; ++j) {
                for (int i = -kernel_rad; i <= kernel_rad; ++i) {
                    // Each thread accumulates a strided subset of channels.
                    for (int ch = c; ch < pdimc; ch += THREADS_PER_BLOCK) {
                        int indx1 = n * pdimyxc + (y1 + j) * pdimxc + (x1 + i) * pdimc + ch;
                        int indx2 = n * pdimyxc + (y2 + j) * pdimxc + (x2 + i) * pdimc + ch;
                        prod_sum[c] += rInput1[indx1] * rInput2[indx2];
                    }
                }
            }

            // accumulate: thread 0 reduces the partial sums and writes the
            // normalized correlation for this displacement.
            __syncthreads();

            if (c == 0) {
                scalar_t reduce_sum = 0;
                for (int index = 0; index < THREADS_PER_BLOCK; ++index) {
                    reduce_sum += prod_sum[index];
                }
                int tc = (tj + displacement_rad) * displacement_size + (ti + displacement_rad);
                const int tindx = n * tdimcyx + tc * tdimyx + blockIdx.y * tdimx + blockIdx.z;
                output[tindx] = reduce_sum / nelems;
            }

            // FIX: barrier before the next displacement overwrites prod_sum.
            __syncthreads();
        }
    }
}
// Gradient of the correlation w.r.t. the first input, for one batch item.
// Launched with grid (inputHeight, inputWidth, channels); the block's threads
// stride over output channels (displacements) and reduce in shared memory.
// FIX: restored the `<typename scalar_t>` template parameter list that was
// stripped from the header.
template <typename scalar_t>
__global__ void correlation_backward_input1(int item, scalar_t* gradInput1, int nInputChannels, int inputHeight, int inputWidth,
    const scalar_t* __restrict__ gradOutput, int nOutputChannels, int outputHeight, int outputWidth,
    const scalar_t* __restrict__ rInput2,
    int pad_size,
    int kernel_size,
    int max_displacement,
    int stride1,
    int stride2)
{
    // n (batch size), c (num of channels), y (height), x (width)
    int n = item;
    int y = blockIdx.x * stride1 + pad_size;
    int x = blockIdx.y * stride1 + pad_size;
    int c = blockIdx.z;
    int tch_off = threadIdx.x;

    int kernel_rad = (kernel_size - 1) / 2;
    int displacement_rad = max_displacement / stride2;
    int displacement_size = 2 * displacement_rad + 1;

    // Output-space window that this input position contributed to.
    int xmin = (x - kernel_rad - max_displacement) / stride1;
    int ymin = (y - kernel_rad - max_displacement) / stride1;

    int xmax = (x + kernel_rad - max_displacement) / stride1;
    int ymax = (y + kernel_rad - max_displacement) / stride1;

    if (xmax < 0 || ymax < 0 || xmin >= outputWidth || ymin >= outputHeight) {
        // assumes gradInput1 is pre-allocated and zero filled
        return;
    }

    if (xmin > xmax || ymin > ymax) {
        // assumes gradInput1 is pre-allocated and zero filled
        return;
    }

    xmin = max(0, xmin);
    xmax = min(outputWidth - 1, xmax);

    ymin = max(0, ymin);
    ymax = min(outputHeight - 1, ymax);

    int pInputWidth = inputWidth + 2 * pad_size;
    int pInputHeight = inputHeight + 2 * pad_size;

    int pdimyxc = pInputHeight * pInputWidth * nInputChannels;
    int pdimxc = pInputWidth * nInputChannels;
    int pdimc = nInputChannels;

    int tdimcyx = nOutputChannels * outputHeight * outputWidth;
    int tdimyx = outputHeight * outputWidth;
    int tdimx = outputWidth;

    int odimcyx = nInputChannels * inputHeight* inputWidth;
    int odimyx = inputHeight * inputWidth;
    int odimx = inputWidth;

    scalar_t nelems = kernel_size * kernel_size * nInputChannels;

    __shared__ scalar_t prod_sum[THREADS_PER_BLOCK];
    prod_sum[tch_off] = 0;

    // Each thread handles a strided subset of displacements (output channels).
    for (int tc = tch_off; tc < nOutputChannels; tc += THREADS_PER_BLOCK) {

        int i2 = (tc % displacement_size - displacement_rad) * stride2;
        int j2 = (tc / displacement_size - displacement_rad) * stride2;

        int indx2 = n * pdimyxc + (y + j2)* pdimxc + (x + i2) * pdimc + c;

        scalar_t val2 = rInput2[indx2];

        for (int j = ymin; j <= ymax; ++j) {
            for (int i = xmin; i <= xmax; ++i) {
                int tindx = n * tdimcyx + tc * tdimyx + j * tdimx + i;
                prod_sum[tch_off] += gradOutput[tindx] * val2;
            }
        }
    }
    __syncthreads();

    if (tch_off == 0) {
        scalar_t reduce_sum = 0;
        for (int idx = 0; idx < THREADS_PER_BLOCK; idx++) {
            reduce_sum += prod_sum[idx];
        }
        const int indx1 = n * odimcyx + c * odimyx + (y - pad_size) * odimx + (x - pad_size);
        gradInput1[indx1] = reduce_sum / nelems;
    }
}
// Gradient of the correlation w.r.t. the second input, for one batch item.
// Mirrors correlation_backward_input1 but the output window depends on the
// displacement (i2, j2), so the bounds check lives inside the loop.
// FIX: restored the `<typename scalar_t>` template parameter list that was
// stripped from the header.
template <typename scalar_t>
__global__ void correlation_backward_input2(int item, scalar_t* gradInput2, int nInputChannels, int inputHeight, int inputWidth,
    const scalar_t* __restrict__ gradOutput, int nOutputChannels, int outputHeight, int outputWidth,
    const scalar_t* __restrict__ rInput1,
    int pad_size,
    int kernel_size,
    int max_displacement,
    int stride1,
    int stride2)
{
    // n (batch size), c (num of channels), y (height), x (width)
    int n = item;
    int y = blockIdx.x * stride1 + pad_size;
    int x = blockIdx.y * stride1 + pad_size;
    int c = blockIdx.z;

    int tch_off = threadIdx.x;

    int kernel_rad = (kernel_size - 1) / 2;
    int displacement_rad = max_displacement / stride2;
    int displacement_size = 2 * displacement_rad + 1;

    int pInputWidth = inputWidth + 2 * pad_size;
    int pInputHeight = inputHeight + 2 * pad_size;

    int pdimyxc = pInputHeight * pInputWidth * nInputChannels;
    int pdimxc = pInputWidth * nInputChannels;
    int pdimc = nInputChannels;

    int tdimcyx = nOutputChannels * outputHeight * outputWidth;
    int tdimyx = outputHeight * outputWidth;
    int tdimx = outputWidth;

    int odimcyx = nInputChannels * inputHeight* inputWidth;
    int odimyx = inputHeight * inputWidth;
    int odimx = inputWidth;

    scalar_t nelems = kernel_size * kernel_size * nInputChannels;

    __shared__ scalar_t prod_sum[THREADS_PER_BLOCK];
    prod_sum[tch_off] = 0;

    for (int tc = tch_off; tc < nOutputChannels; tc += THREADS_PER_BLOCK) {
        int i2 = (tc % displacement_size - displacement_rad) * stride2;
        int j2 = (tc / displacement_size - displacement_rad) * stride2;

        // Output-space window this position contributed to, shifted by the
        // displacement for this output channel.
        int xmin = (x - kernel_rad - max_displacement - i2) / stride1;
        int ymin = (y - kernel_rad - max_displacement - j2) / stride1;

        int xmax = (x + kernel_rad - max_displacement - i2) / stride1;
        int ymax = (y + kernel_rad - max_displacement - j2) / stride1;

        if (xmax < 0 || ymax < 0 || xmin >= outputWidth || ymin >= outputHeight) {
            // assumes gradInput2 is pre-allocated and zero filled
            continue;
        }

        if (xmin > xmax || ymin > ymax) {
            // assumes gradInput2 is pre-allocated and zero filled
            continue;
        }

        xmin = max(0, xmin);
        xmax = min(outputWidth - 1, xmax);

        ymin = max(0, ymin);
        ymax = min(outputHeight - 1, ymax);

        int indx1 = n * pdimyxc + (y - j2)* pdimxc + (x - i2) * pdimc + c;
        scalar_t val1 = rInput1[indx1];

        for (int j = ymin; j <= ymax; ++j) {
            for (int i = xmin; i <= xmax; ++i) {
                int tindx = n * tdimcyx + tc * tdimyx + j * tdimx + i;
                prod_sum[tch_off] += gradOutput[tindx] * val1;
            }
        }
    }

    __syncthreads();

    if (tch_off == 0) {
        scalar_t reduce_sum = 0;
        for (int idx = 0; idx < THREADS_PER_BLOCK; idx++) {
            reduce_sum += prod_sum[idx];
        }
        const int indx2 = n * odimcyx + c * odimyx + (y - pad_size) * odimx + (x - pad_size);
        gradInput2[indx2] = reduce_sum / nelems;
    }
}
// Launches the forward kernels: re-packs both inputs into padded NHWC
// buffers, then computes the cost volume. Returns 1 on success, 0 on error.
// FIX: restored the `<scalar_t>` template arguments, the
// `<<<grid, block, shared_mem, stream>>>` launch configurations and the
// `<scalar_t>` argument of at::Tensor::data, all of which were stripped
// from the kernel launches.
int correlation_forward_cuda_kernel(at::Tensor& output,
    int ob,
    int oc,
    int oh,
    int ow,
    int osb,
    int osc,
    int osh,
    int osw,

    at::Tensor& input1,
    int ic,
    int ih,
    int iw,
    int isb,
    int isc,
    int ish,
    int isw,

    at::Tensor& input2,
    int gc,
    int gsb,
    int gsc,
    int gsh,
    int gsw,

    at::Tensor& rInput1,
    at::Tensor& rInput2,
    int pad_size,
    int kernel_size,
    int max_displacement,
    int stride1,
    int stride2,
    int corr_type_multiply,
    cudaStream_t stream)
{
  int batchSize = ob;

  int nInputChannels = ic;
  int inputWidth = iw;
  int inputHeight = ih;

  int nOutputChannels = oc;
  int outputWidth = ow;
  int outputHeight = oh;

  // One block per (batch, y, x) position for the layout conversion.
  dim3 blocks_grid(batchSize, inputHeight, inputWidth);
  dim3 threads_block(THREADS_PER_BLOCK);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "channels_first_fwd_1", ([&] {
    channels_first<scalar_t><<<blocks_grid, threads_block, 0, stream>>>(
        input1.data<scalar_t>(), rInput1.data<scalar_t>(), nInputChannels, inputHeight, inputWidth, pad_size);
  }));

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "channels_first_fwd_2", ([&] {
    channels_first<scalar_t><<<blocks_grid, threads_block, 0, stream>>>(
        input2.data<scalar_t>(), rInput2.data<scalar_t>(), nInputChannels, inputHeight, inputWidth, pad_size);
  }));

  // One block per (batch, output y, output x) position for the correlation.
  dim3 threadsPerBlock(THREADS_PER_BLOCK);
  dim3 totalBlocksCorr(batchSize, outputHeight, outputWidth);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "correlation_forward", ([&] {
    correlation_forward<scalar_t><<<totalBlocksCorr, threadsPerBlock, 0, stream>>>
        (output.data<scalar_t>(), nOutputChannels, outputHeight, outputWidth,
         rInput1.data<scalar_t>(), nInputChannels, inputHeight, inputWidth,
         rInput2.data<scalar_t>(),
         pad_size,
         kernel_size,
         max_displacement,
         stride1,
         stride2);
  }));

  cudaError_t err = cudaGetLastError();

  // check for errors
  if (err != cudaSuccess) {
    printf("error in correlation_forward_cuda_kernel: %s\n", cudaGetErrorString(err));
    return 0;
  }

  return 1;
}
// Launches the backward kernels: re-packs both inputs into padded NHWC
// buffers, then computes gradInput1 and gradInput2 one batch item at a time.
// Returns 1 on success, 0 on error.
// FIX: restored the `<scalar_t>` template arguments, the
// `<<<grid, block, shared_mem, stream>>>` launch configurations and the
// `<scalar_t>` argument of at::Tensor::data, all of which were stripped
// from the kernel launches.
int correlation_backward_cuda_kernel(
    at::Tensor& gradOutput,
    int gob,
    int goc,
    int goh,
    int gow,
    int gosb,
    int gosc,
    int gosh,
    int gosw,

    at::Tensor& input1,
    int ic,
    int ih,
    int iw,
    int isb,
    int isc,
    int ish,
    int isw,

    at::Tensor& input2,
    int gsb,
    int gsc,
    int gsh,
    int gsw,

    at::Tensor& gradInput1,
    int gisb,
    int gisc,
    int gish,
    int gisw,

    at::Tensor& gradInput2,
    int ggc,
    int ggsb,
    int ggsc,
    int ggsh,
    int ggsw,

    at::Tensor& rInput1,
    at::Tensor& rInput2,
    int pad_size,
    int kernel_size,
    int max_displacement,
    int stride1,
    int stride2,
    int corr_type_multiply,
    cudaStream_t stream)
{
  int batchSize = gob;
  int num = batchSize;

  int nInputChannels = ic;
  int inputWidth = iw;
  int inputHeight = ih;

  int nOutputChannels = goc;
  int outputWidth = gow;
  int outputHeight = goh;

  // Layout conversion: one block per (batch, y, x) position.
  dim3 blocks_grid(batchSize, inputHeight, inputWidth);
  dim3 threads_block(THREADS_PER_BLOCK);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "lltm_forward_cuda", ([&] {
    channels_first<scalar_t><<<blocks_grid, threads_block, 0, stream>>>(
        input1.data<scalar_t>(),
        rInput1.data<scalar_t>(),
        nInputChannels,
        inputHeight,
        inputWidth,
        pad_size
    );
  }));

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "lltm_forward_cuda", ([&] {
    channels_first<scalar_t><<<blocks_grid, threads_block, 0, stream>>>(
        input2.data<scalar_t>(),
        rInput2.data<scalar_t>(),
        nInputChannels,
        inputHeight,
        inputWidth,
        pad_size
    );
  }));

  // Backward kernels: one block per (y, x, channel); batch handled serially.
  dim3 threadsPerBlock(THREADS_PER_BLOCK);
  dim3 totalBlocksCorr(inputHeight, inputWidth, nInputChannels);

  for (int n = 0; n < num; ++n) {
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "lltm_forward_cuda", ([&] {
      correlation_backward_input1<scalar_t><<<totalBlocksCorr, threadsPerBlock, 0, stream>>> (
          n, gradInput1.data<scalar_t>(), nInputChannels, inputHeight, inputWidth,
          gradOutput.data<scalar_t>(), nOutputChannels, outputHeight, outputWidth,
          rInput2.data<scalar_t>(),
          pad_size,
          kernel_size,
          max_displacement,
          stride1,
          stride2);
    }));
  }

  for (int n = 0; n < batchSize; n++) {
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(rInput1.type(), "lltm_forward_cuda", ([&] {
      correlation_backward_input2<scalar_t><<<totalBlocksCorr, threadsPerBlock, 0, stream>>>(
          n, gradInput2.data<scalar_t>(), nInputChannels, inputHeight, inputWidth,
          gradOutput.data<scalar_t>(), nOutputChannels, outputHeight, outputWidth,
          rInput1.data<scalar_t>(),
          pad_size,
          kernel_size,
          max_displacement,
          stride1,
          stride2);
    }));
  }

  // check for errors
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    printf("error in correlation_backward_cuda_kernel: %s\n", cudaGetErrorString(err));
    return 0;
  }

  return 1;
}
================================================
FILE: models/correlation_package/correlation_cuda_kernel.cuh
================================================
#pragma once

// NOTE(review): the three include targets below were lost during text
// extraction (angle-bracket content stripped); they presumably were
// <ATen/ATen.h>, <cuda.h> and <cuda_runtime.h> — restore before building.
#include
#include
#include

// Launches the forward correlation kernels (layout conversion + cost
// volume). Sizes/strides are passed explicitly alongside the tensors.
// Returns 1 on success, 0 on CUDA error.
int correlation_forward_cuda_kernel(at::Tensor& output,
    int ob,
    int oc,
    int oh,
    int ow,
    int osb,
    int osc,
    int osh,
    int osw,

    at::Tensor& input1,
    int ic,
    int ih,
    int iw,
    int isb,
    int isc,
    int ish,
    int isw,

    at::Tensor& input2,
    int gc,
    int gsb,
    int gsc,
    int gsh,
    int gsw,

    at::Tensor& rInput1,
    at::Tensor& rInput2,
    int pad_size,
    int kernel_size,
    int max_displacement,
    int stride1,
    int stride2,
    int corr_type_multiply,
    cudaStream_t stream);

// Launches the backward correlation kernels, filling gradInput1/gradInput2.
// Returns 1 on success, 0 on CUDA error.
int correlation_backward_cuda_kernel(
    at::Tensor& gradOutput,
    int gob,
    int goc,
    int goh,
    int gow,
    int gosb,
    int gosc,
    int gosh,
    int gosw,

    at::Tensor& input1,
    int ic,
    int ih,
    int iw,
    int isb,
    int isc,
    int ish,
    int isw,

    at::Tensor& input2,
    int gsb,
    int gsc,
    int gsh,
    int gsw,

    at::Tensor& gradInput1,
    int gisb,
    int gisc,
    int gish,
    int gisw,

    at::Tensor& gradInput2,
    int ggc,
    int ggsb,
    int ggsc,
    int ggsh,
    int ggsw,

    at::Tensor& rInput1,
    at::Tensor& rInput2,
    int pad_size,
    int kernel_size,
    int max_displacement,
    int stride1,
    int stride2,
    int corr_type_multiply,
    cudaStream_t stream);
================================================
FILE: models/correlation_package/setup.py
================================================
#!/usr/bin/env python3
"""Build script for the `correlation_cuda` CUDA extension.

Build/install with: python setup.py install
"""
import os
import torch
from setuptools import setup, find_packages
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

# Host compiler flags.
cxx_args = ['-std=c++11']

# Target GPU architectures (Maxwell through Pascal); extend the list to
# support newer compute capabilities.
nvcc_args = [
    '-gencode', 'arch=compute_50,code=sm_50',
    '-gencode', 'arch=compute_52,code=sm_52',
    '-gencode', 'arch=compute_60,code=sm_60',
    '-gencode', 'arch=compute_61,code=sm_61',
    '-gencode', 'arch=compute_61,code=compute_61'
]

setup(
    name='correlation_cuda',
    ext_modules=[
        CUDAExtension('correlation_cuda', [
            'correlation_cuda.cc',
            'correlation_cuda_kernel.cu'
        ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args})
    ],
    cmdclass={
        'build_ext': BuildExtension
    })
================================================
FILE: models/correlation_package_cu9/__init__.py
================================================
================================================
FILE: models/correlation_package_cu9/correlation.py
================================================
import torch
from torch.nn.modules.module import Module
from torch.autograd import Function
import correlation_cuda
class CorrelationFunction(Function):
    """Autograd wrapper around the compiled `correlation_cuda` extension
    (CUDA 9 build variant).

    NOTE(review): uses the legacy, instantiation-based
    torch.autograd.Function API (state on `self`, non-static
    forward/backward) — only supported by older PyTorch versions; confirm
    against the torch version this repo pins.
    """

    def __init__(self, pad_size=3, kernel_size=3, max_displacement=20, stride1=1, stride2=2, corr_multiply=1):
        super(CorrelationFunction, self).__init__()
        self.pad_size = pad_size
        self.kernel_size = kernel_size
        self.max_displacement = max_displacement
        self.stride1 = stride1
        self.stride2 = stride2
        self.corr_multiply = corr_multiply
        # self.out_channel = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1)

    def forward(self, input1, input2):
        # Keep the inputs for the backward pass.
        self.save_for_backward(input1, input2)

        with torch.cuda.device_of(input1):
            # Empty tensors on the inputs' device/dtype; the extension
            # resizes them (scratch buffers rbot1/rbot2 and the output).
            rbot1 = input1.new()
            rbot2 = input2.new()
            output = input1.new()

            correlation_cuda.forward(input1, input2, rbot1, rbot2, output,
                self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply)

        return output

    def backward(self, grad_output):
        input1, input2 = self.saved_tensors

        with torch.cuda.device_of(input1):
            rbot1 = input1.new()
            rbot2 = input2.new()
            grad_input1 = input1.new()
            grad_input2 = input2.new()

            correlation_cuda.backward(input1, input2, rbot1, rbot2, grad_output, grad_input1, grad_input2,
                self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply)

        return grad_input1, grad_input2
class Correlation(Module):
    """Module-style interface to the CUDA correlation layer (CUDA 9 build)."""

    def __init__(self, pad_size=0, kernel_size=0, max_displacement=0, stride1=1, stride2=2, corr_multiply=1):
        super(Correlation, self).__init__()
        # Correlation hyper-parameters, handed to CorrelationFunction per call.
        self.pad_size = pad_size
        self.kernel_size = kernel_size
        self.max_displacement = max_displacement
        self.stride1 = stride1
        self.stride2 = stride2
        self.corr_multiply = corr_multiply

    def forward(self, input1, input2):
        # A fresh function object per call (legacy autograd style).
        func = CorrelationFunction(self.pad_size, self.kernel_size,
                                   self.max_displacement, self.stride1,
                                   self.stride2, self.corr_multiply)
        return func(input1, input2)
================================================
FILE: models/correlation_package_cu9/correlation_cuda.cc
================================================
#include
#include
#include
#include
#include
#include
#include "correlation_cuda_kernel.cuh"
// Host-side entry point for the correlation forward pass (CUDA 9 variant;
// uses at::cuda::getCurrentCUDAStream()). Resizes the padded channels-last
// scratch buffers and the output, zero-fills them, and launches the kernels.
// Returns 1 on success; raises via AT_ERROR on kernel failure.
int correlation_forward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& output,
                             int pad_size,
                             int kernel_size,
                             int max_displacement,
                             int stride1,
                             int stride2,
                             int corr_type_multiply)
{
  int batchSize = input1.size(0);
  int nInputChannels = input1.size(1);
  int inputHeight = input1.size(2);
  int inputWidth = input1.size(3);

  int kernel_radius = (kernel_size - 1) / 2;
  int border_radius = kernel_radius + max_displacement;

  int paddedInputHeight = inputHeight + 2 * pad_size;
  int paddedInputWidth = inputWidth + 2 * pad_size;

  // One output channel per displacement in the search window.
  int nOutputChannels = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1);

  // FIX: restored the <float> template arguments stripped from static_cast
  // (a bare `static_cast(...)` does not compile).
  int outputHeight = ceil(static_cast<float>(paddedInputHeight - 2 * border_radius) / static_cast<float>(stride1));
  int outputwidth = ceil(static_cast<float>(paddedInputWidth - 2 * border_radius) / static_cast<float>(stride1));

  // rInput* hold the inputs re-packed into padded NHWC layout.
  rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
  rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
  output.resize_({batchSize, nOutputChannels, outputHeight, outputwidth});

  rInput1.fill_(0);
  rInput2.fill_(0);
  output.fill_(0);

  int success = correlation_forward_cuda_kernel(
    output,
    output.size(0),
    output.size(1),
    output.size(2),
    output.size(3),
    output.stride(0),
    output.stride(1),
    output.stride(2),
    output.stride(3),
    input1,
    input1.size(1),
    input1.size(2),
    input1.size(3),
    input1.stride(0),
    input1.stride(1),
    input1.stride(2),
    input1.stride(3),
    input2,
    input2.size(1),
    input2.stride(0),
    input2.stride(1),
    input2.stride(2),
    input2.stride(3),
    rInput1,
    rInput2,
    pad_size,
    kernel_size,
    max_displacement,
    stride1,
    stride2,
    corr_type_multiply,
    at::cuda::getCurrentCUDAStream()
    //at::globalContext().getCurrentCUDAStream()
  );

  //check for errors
  if (!success) {
    AT_ERROR("CUDA call failed");
  }

  return 1;
}
// Host-side backward entry point exposed to Python.
// Mirrors correlation_forward_cuda: allocates the channels-last scratch
// copies and the NCHW gradient tensors, then launches
// correlation_backward_cuda_kernel. Returns 1 on success; raises on failure.
int correlation_backward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& gradOutput,
                              at::Tensor& gradInput1, at::Tensor& gradInput2,
                              int pad_size,
                              int kernel_size,
                              int max_displacement,
                              int stride1,
                              int stride2,
                              int corr_type_multiply)
{
    int batchSize = input1.size(0);
    int nInputChannels = input1.size(1);
    int paddedInputHeight = input1.size(2)+ 2 * pad_size;
    int paddedInputWidth = input1.size(3)+ 2 * pad_size;

    int height = input1.size(2);
    int width = input1.size(3);

    // Zero-padded NHWC scratch copies of both inputs (filled by the
    // channels_first kernel inside the backward launcher).
    rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
    rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
    // Gradients come back in the original NCHW input shape.
    gradInput1.resize_({batchSize, nInputChannels, height, width});
    gradInput2.resize_({batchSize, nInputChannels, height, width});

    rInput1.fill_(0);
    rInput2.fill_(0);
    gradInput1.fill_(0);
    gradInput2.fill_(0);

    // NOTE(review): the long scalar argument list must match the kernel
    // signature positionally; several of the stride args appear unused
    // downstream — verify against correlation_backward_cuda_kernel.
    int success = correlation_backward_cuda_kernel(gradOutput,
        gradOutput.size(0),
        gradOutput.size(1),
        gradOutput.size(2),
        gradOutput.size(3),
        gradOutput.stride(0),
        gradOutput.stride(1),
        gradOutput.stride(2),
        gradOutput.stride(3),
        input1,
        input1.size(1),
        input1.size(2),
        input1.size(3),
        input1.stride(0),
        input1.stride(1),
        input1.stride(2),
        input1.stride(3),
        input2,
        input2.stride(0),
        input2.stride(1),
        input2.stride(2),
        input2.stride(3),
        gradInput1,
        gradInput1.stride(0),
        gradInput1.stride(1),
        gradInput1.stride(2),
        gradInput1.stride(3),
        gradInput2,
        gradInput2.size(1),
        gradInput2.stride(0),
        gradInput2.stride(1),
        gradInput2.stride(2),
        gradInput2.stride(3),
        rInput1,
        rInput2,
        pad_size,
        kernel_size,
        max_displacement,
        stride1,
        stride2,
        corr_type_multiply,
        at::cuda::getCurrentCUDAStream()
    );

    if (!success) {
        AT_ERROR("CUDA call failed");
    }

    return 1;
}
// Python bindings: exposes the host entry points as the extension module
// `correlation_cuda` (name supplied by TORCH_EXTENSION_NAME at build time).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &correlation_forward_cuda, "Correlation forward (CUDA)");
    m.def("backward", &correlation_backward_cuda, "Correlation backward (CUDA)");
}
================================================
FILE: models/correlation_package_cu9/correlation_cuda_kernel.cu
================================================
#include
#include "correlation_cuda_kernel.cuh"
#define CUDA_NUM_THREADS 1024
#define THREADS_PER_BLOCK 32
#include
#include
#include
#include
using at::Half;
// Repack one NCHW input tensor into a zero-padded NHWC (channels-last)
// buffer. Grid: one block per (batch n, row y, col x); the block's threads
// cooperate over the channel dimension in steps of THREADS_PER_BLOCK.
// FIX: restore the template parameter list lost in extraction — a bare
// `template` does not compile.
template <typename scalar_t>
__global__ void channels_first(const scalar_t* __restrict__ input, scalar_t* rinput, int channels, int height, int width, int pad_size)
{
    // n (batch size), c (num of channels), y (height), x (width)
    int n = blockIdx.x;
    int y = blockIdx.y;
    int x = blockIdx.z;
    int ch_off = threadIdx.x;
    scalar_t value;

    int dimcyx = channels * height * width;     // NCHW batch stride
    int dimyx = height * width;                 // NCHW channel stride
    int p_dimx = (width + 2 * pad_size);        // padded width
    int p_dimy = (height + 2 * pad_size);       // padded height
    int p_dimyxc = channels * p_dimy * p_dimx;  // NHWC batch stride
    int p_dimxc = p_dimx * channels;            // NHWC row stride

    for (int c = ch_off; c < channels; c += THREADS_PER_BLOCK) {
        value = input[n * dimcyx + c * dimyx + y * width + x];
        rinput[n * p_dimyxc + (y + pad_size) * p_dimxc + (x + pad_size) * channels + c] = value;
    }
}
// Correlation forward kernel. One block per output pixel (n, y, x);
// THREADS_PER_BLOCK threads cooperate over the channel/kernel-window
// products for every displacement, reducing through shared memory.
// Inputs are the padded NHWC buffers produced by channels_first.
// FIX 1: restore the template parameter list lost in extraction.
// FIX 2: add a __syncthreads() after the reduction — without it, threads
// entering the next displacement iteration reset prod_sum[c] while thread 0
// may still be reading it (shared-memory race).
template <typename scalar_t>
__global__ void correlation_forward(scalar_t* output, int nOutputChannels, int outputHeight, int outputWidth,
    const scalar_t* __restrict__ rInput1, int nInputChannels, int inputHeight, int inputWidth,
    const scalar_t* __restrict__ rInput2,
    int pad_size,
    int kernel_size,
    int max_displacement,
    int stride1,
    int stride2)
{
    // n (batch size), c (num of channels), y (height), x (width)
    int pInputWidth = inputWidth + 2 * pad_size;
    int pInputHeight = inputHeight + 2 * pad_size;

    int kernel_rad = (kernel_size - 1) / 2;
    int displacement_rad = max_displacement / stride2;
    int displacement_size = 2 * displacement_rad + 1;

    int n = blockIdx.x;
    int y1 = blockIdx.y * stride1 + max_displacement;
    int x1 = blockIdx.z * stride1 + max_displacement;
    int c = threadIdx.x;

    // Strides into the padded NHWC inputs.
    int pdimyxc = pInputHeight * pInputWidth * nInputChannels;
    int pdimxc = pInputWidth * nInputChannels;
    int pdimc = nInputChannels;

    // Strides into the NCHW output.
    int tdimcyx = nOutputChannels * outputHeight * outputWidth;
    int tdimyx = outputHeight * outputWidth;
    int tdimx = outputWidth;

    // Normalizer: number of multiplied elements per correlation sample.
    scalar_t nelems = kernel_size * kernel_size * pdimc;

    __shared__ scalar_t prod_sum[THREADS_PER_BLOCK];

    // no significant speed-up in using chip memory for input1 sub-data,
    // not enough chip memory size to accomodate memory per block for input2 sub-data
    // instead i've used device memory for both

    // element-wise product along channel axis
    for (int tj = -displacement_rad; tj <= displacement_rad; ++tj) {
        for (int ti = -displacement_rad; ti <= displacement_rad; ++ti) {
            prod_sum[c] = 0;
            int x2 = x1 + ti * stride2;
            int y2 = y1 + tj * stride2;

            for (int j = -kernel_rad; j <= kernel_rad; ++j) {
                for (int i = -kernel_rad; i <= kernel_rad; ++i) {
                    for (int ch = c; ch < pdimc; ch += THREADS_PER_BLOCK) {
                        int indx1 = n * pdimyxc + (y1 + j) * pdimxc + (x1 + i) * pdimc + ch;
                        int indx2 = n * pdimyxc + (y2 + j) * pdimxc + (x2 + i) * pdimc + ch;
                        prod_sum[c] += rInput1[indx1] * rInput2[indx2];
                    }
                }
            }

            // accumulate: thread 0 reduces the per-thread partial sums.
            __syncthreads();
            if (c == 0) {
                scalar_t reduce_sum = 0;
                for (int index = 0; index < THREADS_PER_BLOCK; ++index) {
                    reduce_sum += prod_sum[index];
                }
                int tc = (tj + displacement_rad) * displacement_size + (ti + displacement_rad);
                const int tindx = n * tdimcyx + tc * tdimyx + blockIdx.y * tdimx + blockIdx.z;
                output[tindx] = reduce_sum / nelems;
            }
            // Barrier before the next displacement resets prod_sum, so other
            // threads cannot overwrite slots thread 0 is still reading.
            __syncthreads();
        }
    }
}
// Gradient of the correlation w.r.t. input1, for a single batch item
// (`item`). Grid: one block per (y, x, channel) of input1; threads stride
// over the output channels (displacements) and reduce via shared memory.
// FIX: restore the template parameter list lost in extraction.
template <typename scalar_t>
__global__ void correlation_backward_input1(int item, scalar_t* gradInput1, int nInputChannels, int inputHeight, int inputWidth,
    const scalar_t* __restrict__ gradOutput, int nOutputChannels, int outputHeight, int outputWidth,
    const scalar_t* __restrict__ rInput2,
    int pad_size,
    int kernel_size,
    int max_displacement,
    int stride1,
    int stride2)
{
    // n (batch size), c (num of channels), y (height), x (width)
    int n = item;
    int y = blockIdx.x * stride1 + pad_size;
    int x = blockIdx.y * stride1 + pad_size;
    int c = blockIdx.z;
    int tch_off = threadIdx.x;

    int kernel_rad = (kernel_size - 1) / 2;
    int displacement_rad = max_displacement / stride2;
    int displacement_size = 2 * displacement_rad + 1;

    // Range of output positions whose receptive window contains (y, x).
    int xmin = (x - kernel_rad - max_displacement) / stride1;
    int ymin = (y - kernel_rad - max_displacement) / stride1;

    int xmax = (x + kernel_rad - max_displacement) / stride1;
    int ymax = (y + kernel_rad - max_displacement) / stride1;

    if (xmax < 0 || ymax < 0 || xmin >= outputWidth || ymin >= outputHeight) {
        // assumes gradInput1 is pre-allocated and zero filled
        return;
    }

    if (xmin > xmax || ymin > ymax) {
        // assumes gradInput1 is pre-allocated and zero filled
        return;
    }

    xmin = max(0, xmin);
    xmax = min(outputWidth - 1, xmax);

    ymin = max(0, ymin);
    ymax = min(outputHeight - 1, ymax);

    int pInputWidth = inputWidth + 2 * pad_size;
    int pInputHeight = inputHeight + 2 * pad_size;

    // Strides into the padded NHWC rInput2, the NCHW gradOutput, and the
    // NCHW gradInput1.
    int pdimyxc = pInputHeight * pInputWidth * nInputChannels;
    int pdimxc = pInputWidth * nInputChannels;
    int pdimc = nInputChannels;

    int tdimcyx = nOutputChannels * outputHeight * outputWidth;
    int tdimyx = outputHeight * outputWidth;
    int tdimx = outputWidth;

    int odimcyx = nInputChannels * inputHeight * inputWidth;
    int odimyx = inputHeight * inputWidth;
    int odimx = inputWidth;

    scalar_t nelems = kernel_size * kernel_size * nInputChannels;

    __shared__ scalar_t prod_sum[THREADS_PER_BLOCK];
    prod_sum[tch_off] = 0;

    for (int tc = tch_off; tc < nOutputChannels; tc += THREADS_PER_BLOCK) {
        // Decode the displacement (i2, j2) encoded by output channel tc.
        int i2 = (tc % displacement_size - displacement_rad) * stride2;
        int j2 = (tc / displacement_size - displacement_rad) * stride2;

        int indx2 = n * pdimyxc + (y + j2) * pdimxc + (x + i2) * pdimc + c;

        scalar_t val2 = rInput2[indx2];

        for (int j = ymin; j <= ymax; ++j) {
            for (int i = xmin; i <= xmax; ++i) {
                int tindx = n * tdimcyx + tc * tdimyx + j * tdimx + i;
                prod_sum[tch_off] += gradOutput[tindx] * val2;
            }
        }
    }
    __syncthreads();

    if (tch_off == 0) {
        scalar_t reduce_sum = 0;
        for (int idx = 0; idx < THREADS_PER_BLOCK; idx++) {
            reduce_sum += prod_sum[idx];
        }
        const int indx1 = n * odimcyx + c * odimyx + (y - pad_size) * odimx + (x - pad_size);
        gradInput1[indx1] = reduce_sum / nelems;
    }
}
// Gradient of the correlation w.r.t. input2, for a single batch item
// (`item`). Same layout as correlation_backward_input1, but the valid
// output range depends on each displacement (i2, j2), so the bounds are
// recomputed inside the tc loop.
// FIX: restore the template parameter list lost in extraction.
template <typename scalar_t>
__global__ void correlation_backward_input2(int item, scalar_t* gradInput2, int nInputChannels, int inputHeight, int inputWidth,
    const scalar_t* __restrict__ gradOutput, int nOutputChannels, int outputHeight, int outputWidth,
    const scalar_t* __restrict__ rInput1,
    int pad_size,
    int kernel_size,
    int max_displacement,
    int stride1,
    int stride2)
{
    // n (batch size), c (num of channels), y (height), x (width)
    int n = item;
    int y = blockIdx.x * stride1 + pad_size;
    int x = blockIdx.y * stride1 + pad_size;
    int c = blockIdx.z;

    int tch_off = threadIdx.x;

    int kernel_rad = (kernel_size - 1) / 2;
    int displacement_rad = max_displacement / stride2;
    int displacement_size = 2 * displacement_rad + 1;

    int pInputWidth = inputWidth + 2 * pad_size;
    int pInputHeight = inputHeight + 2 * pad_size;

    // Strides into the padded NHWC rInput1, the NCHW gradOutput, and the
    // NCHW gradInput2.
    int pdimyxc = pInputHeight * pInputWidth * nInputChannels;
    int pdimxc = pInputWidth * nInputChannels;
    int pdimc = nInputChannels;

    int tdimcyx = nOutputChannels * outputHeight * outputWidth;
    int tdimyx = outputHeight * outputWidth;
    int tdimx = outputWidth;

    int odimcyx = nInputChannels * inputHeight * inputWidth;
    int odimyx = inputHeight * inputWidth;
    int odimx = inputWidth;

    scalar_t nelems = kernel_size * kernel_size * nInputChannels;

    __shared__ scalar_t prod_sum[THREADS_PER_BLOCK];
    prod_sum[tch_off] = 0;

    for (int tc = tch_off; tc < nOutputChannels; tc += THREADS_PER_BLOCK) {
        // Decode the displacement (i2, j2) encoded by output channel tc.
        int i2 = (tc % displacement_size - displacement_rad) * stride2;
        int j2 = (tc / displacement_size - displacement_rad) * stride2;

        int xmin = (x - kernel_rad - max_displacement - i2) / stride1;
        int ymin = (y - kernel_rad - max_displacement - j2) / stride1;

        int xmax = (x + kernel_rad - max_displacement - i2) / stride1;
        int ymax = (y + kernel_rad - max_displacement - j2) / stride1;

        if (xmax < 0 || ymax < 0 || xmin >= outputWidth || ymin >= outputHeight) {
            // assumes gradInput2 is pre-allocated and zero filled
            continue;
        }

        if (xmin > xmax || ymin > ymax) {
            // assumes gradInput2 is pre-allocated and zero filled
            continue;
        }

        xmin = max(0, xmin);
        xmax = min(outputWidth - 1, xmax);

        ymin = max(0, ymin);
        ymax = min(outputHeight - 1, ymax);

        int indx1 = n * pdimyxc + (y - j2) * pdimxc + (x - i2) * pdimc + c;
        scalar_t val1 = rInput1[indx1];

        for (int j = ymin; j <= ymax; ++j) {
            for (int i = xmin; i <= xmax; ++i) {
                int tindx = n * tdimcyx + tc * tdimyx + j * tdimx + i;
                prod_sum[tch_off] += gradOutput[tindx] * val1;
            }
        }
    }

    __syncthreads();

    if (tch_off == 0) {
        scalar_t reduce_sum = 0;
        for (int idx = 0; idx < THREADS_PER_BLOCK; idx++) {
            reduce_sum += prod_sum[idx];
        }
        const int indx2 = n * odimcyx + c * odimyx + (y - pad_size) * odimx + (x - pad_size);
        gradInput2[indx2] = reduce_sum / nelems;
    }
}
// Forward kernel launcher: repacks both inputs to padded NHWC and then runs
// the correlation kernel on the given stream. Returns 1 on success, 0 if a
// CUDA error was recorded.
// FIX: the kernel launches were garbled in extraction (`<< > >`) — restore
// the `<scalar_t><<<grid, block, shmem, stream>>>` launch syntax and the
// `.data<scalar_t>()` template arguments; without them this does not compile.
int correlation_forward_cuda_kernel(at::Tensor& output,
    int ob,
    int oc,
    int oh,
    int ow,
    int osb,
    int osc,
    int osh,
    int osw,
    at::Tensor& input1,
    int ic,
    int ih,
    int iw,
    int isb,
    int isc,
    int ish,
    int isw,
    at::Tensor& input2,
    int gc,
    int gsb,
    int gsc,
    int gsh,
    int gsw,
    at::Tensor& rInput1,
    at::Tensor& rInput2,
    int pad_size,
    int kernel_size,
    int max_displacement,
    int stride1,
    int stride2,
    int corr_type_multiply,
    cudaStream_t stream)
{
    int batchSize = ob;

    int nInputChannels = ic;
    int inputWidth = iw;
    int inputHeight = ih;

    int nOutputChannels = oc;
    int outputWidth = ow;
    int outputHeight = oh;

    // NCHW -> padded NHWC repack: one block per input pixel.
    dim3 blocks_grid(batchSize, inputHeight, inputWidth);
    dim3 threads_block(THREADS_PER_BLOCK);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "channels_first_fwd_1", ([&] {
        channels_first<scalar_t><<<blocks_grid, threads_block, 0, stream>>>(
            input1.data<scalar_t>(), rInput1.data<scalar_t>(), nInputChannels, inputHeight, inputWidth, pad_size);
    }));

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "channels_first_fwd_2", ([&] {
        channels_first<scalar_t><<<blocks_grid, threads_block, 0, stream>>>(
            input2.data<scalar_t>(), rInput2.data<scalar_t>(), nInputChannels, inputHeight, inputWidth, pad_size);
    }));

    // Correlation: one block per output pixel.
    dim3 threadsPerBlock(THREADS_PER_BLOCK);
    dim3 totalBlocksCorr(batchSize, outputHeight, outputWidth);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "correlation_forward", ([&] {
        correlation_forward<scalar_t><<<totalBlocksCorr, threadsPerBlock, 0, stream>>>(
            output.data<scalar_t>(), nOutputChannels, outputHeight, outputWidth,
            rInput1.data<scalar_t>(), nInputChannels, inputHeight, inputWidth,
            rInput2.data<scalar_t>(),
            pad_size,
            kernel_size,
            max_displacement,
            stride1,
            stride2);
    }));

    // check for errors
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        printf("error in correlation_forward_cuda_kernel: %s\n", cudaGetErrorString(err));
        return 0;
    }

    return 1;
}
// Backward kernel launcher: repacks both inputs to padded NHWC, then launches
// the input1- and input2-gradient kernels once per batch item. Returns 1 on
// success, 0 if a CUDA error was recorded.
// FIX: restore the `<scalar_t><<<grid, block, shmem, stream>>>` launch syntax
// and `.data<scalar_t>()` template arguments garbled in extraction.
// NOTE(review): the dispatch names "lltm_forward_cuda" are leftovers from the
// PyTorch extension tutorial; they only affect error messages.
int correlation_backward_cuda_kernel(
    at::Tensor& gradOutput,
    int gob,
    int goc,
    int goh,
    int gow,
    int gosb,
    int gosc,
    int gosh,
    int gosw,
    at::Tensor& input1,
    int ic,
    int ih,
    int iw,
    int isb,
    int isc,
    int ish,
    int isw,
    at::Tensor& input2,
    int gsb,
    int gsc,
    int gsh,
    int gsw,
    at::Tensor& gradInput1,
    int gisb,
    int gisc,
    int gish,
    int gisw,
    at::Tensor& gradInput2,
    int ggc,
    int ggsb,
    int ggsc,
    int ggsh,
    int ggsw,
    at::Tensor& rInput1,
    at::Tensor& rInput2,
    int pad_size,
    int kernel_size,
    int max_displacement,
    int stride1,
    int stride2,
    int corr_type_multiply,
    cudaStream_t stream)
{
    int batchSize = gob;
    int num = batchSize;

    int nInputChannels = ic;
    int inputWidth = iw;
    int inputHeight = ih;

    int nOutputChannels = goc;
    int outputWidth = gow;
    int outputHeight = goh;

    // NCHW -> padded NHWC repack: one block per input pixel.
    dim3 blocks_grid(batchSize, inputHeight, inputWidth);
    dim3 threads_block(THREADS_PER_BLOCK);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "lltm_forward_cuda", ([&] {
        channels_first<scalar_t><<<blocks_grid, threads_block, 0, stream>>>(
            input1.data<scalar_t>(),
            rInput1.data<scalar_t>(),
            nInputChannels,
            inputHeight,
            inputWidth,
            pad_size
        );
    }));

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "lltm_forward_cuda", ([&] {
        channels_first<scalar_t><<<blocks_grid, threads_block, 0, stream>>>(
            input2.data<scalar_t>(),
            rInput2.data<scalar_t>(),
            nInputChannels,
            inputHeight,
            inputWidth,
            pad_size
        );
    }));

    // Gradient kernels: one block per (y, x, channel), launched per batch item.
    dim3 threadsPerBlock(THREADS_PER_BLOCK);
    dim3 totalBlocksCorr(inputHeight, inputWidth, nInputChannels);

    for (int n = 0; n < num; ++n) {
        AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "lltm_forward_cuda", ([&] {
            correlation_backward_input1<scalar_t><<<totalBlocksCorr, threadsPerBlock, 0, stream>>>(
                n, gradInput1.data<scalar_t>(), nInputChannels, inputHeight, inputWidth,
                gradOutput.data<scalar_t>(), nOutputChannels, outputHeight, outputWidth,
                rInput2.data<scalar_t>(),
                pad_size,
                kernel_size,
                max_displacement,
                stride1,
                stride2);
        }));
    }

    for (int n = 0; n < batchSize; n++) {
        AT_DISPATCH_FLOATING_TYPES_AND_HALF(rInput1.type(), "lltm_forward_cuda", ([&] {
            correlation_backward_input2<scalar_t><<<totalBlocksCorr, threadsPerBlock, 0, stream>>>(
                n, gradInput2.data<scalar_t>(), nInputChannels, inputHeight, inputWidth,
                gradOutput.data<scalar_t>(), nOutputChannels, outputHeight, outputWidth,
                rInput1.data<scalar_t>(),
                pad_size,
                kernel_size,
                max_displacement,
                stride1,
                stride2);
        }));
    }

    // check for errors
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        printf("error in correlation_backward_cuda_kernel: %s\n", cudaGetErrorString(err));
        return 0;
    }

    return 1;
}
================================================
FILE: models/correlation_package_cu9/correlation_cuda_kernel.cuh
================================================
#pragma once
#include
#include
#include
// Forward launcher: sizes (o*), strides (os*), input sizes/strides (i*, is*),
// input2 channel/stride args (g*), scratch buffers, correlation parameters,
// and the CUDA stream to launch on. Returns 1 on success, 0 on CUDA error.
int correlation_forward_cuda_kernel(at::Tensor& output,
    int ob,
    int oc,
    int oh,
    int ow,
    int osb,
    int osc,
    int osh,
    int osw,
    at::Tensor& input1,
    int ic,
    int ih,
    int iw,
    int isb,
    int isc,
    int ish,
    int isw,
    at::Tensor& input2,
    int gc,
    int gsb,
    int gsc,
    int gsh,
    int gsw,
    at::Tensor& rInput1,
    at::Tensor& rInput2,
    int pad_size,
    int kernel_size,
    int max_displacement,
    int stride1,
    int stride2,
    int corr_type_multiply,
    cudaStream_t stream);

// Backward launcher: gradOutput sizes/strides (go*, gos*), input sizes and
// strides, gradient tensors with their strides, scratch buffers, correlation
// parameters, and the CUDA stream. Returns 1 on success, 0 on CUDA error.
int correlation_backward_cuda_kernel(
    at::Tensor& gradOutput,
    int gob,
    int goc,
    int goh,
    int gow,
    int gosb,
    int gosc,
    int gosh,
    int gosw,
    at::Tensor& input1,
    int ic,
    int ih,
    int iw,
    int isb,
    int isc,
    int ish,
    int isw,
    at::Tensor& input2,
    int gsb,
    int gsc,
    int gsh,
    int gsw,
    at::Tensor& gradInput1,
    int gisb,
    int gisc,
    int gish,
    int gisw,
    at::Tensor& gradInput2,
    int ggc,
    int ggsb,
    int ggsc,
    int ggsh,
    int ggsw,
    at::Tensor& rInput1,
    at::Tensor& rInput2,
    int pad_size,
    int kernel_size,
    int max_displacement,
    int stride1,
    int stride2,
    int corr_type_multiply,
    cudaStream_t stream);
================================================
FILE: models/correlation_package_cu9/setup.py
================================================
#!/usr/bin/env python3
import os
import torch
from setuptools import setup, find_packages
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
# Host C++ compiler flags.
cxx_args = ['-std=c++11']

# Target GPU architectures (Maxwell through Pascal).
nvcc_args = [
    '-gencode', 'arch=compute_50,code=sm_50',
    '-gencode', 'arch=compute_52,code=sm_52',
    '-gencode', 'arch=compute_60,code=sm_60',
    '-gencode', 'arch=compute_61,code=sm_61',
    '-gencode', 'arch=compute_61,code=compute_61',
    # NOTE(review): hard-codes the host compiler path; builds will fail on
    # machines without gcc-4.9 at this location.
    '-ccbin', '/usr/bin/gcc-4.9'
]

setup(
    name='correlation_cuda',
    ext_modules=[
        CUDAExtension('correlation_cuda', [
            'correlation_cuda.cc',
            'correlation_cuda_kernel.cu'
        # NOTE(review): 'cuda-path' is not a recognized extra_compile_args
        # key ('cxx' and 'nvcc' are) — verify it has any effect here.
        ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args, 'cuda-path': ['/usr/local/cuda-9.0']})
    ],
    cmdclass={
        'build_ext': BuildExtension
    })
================================================
FILE: models/flownet1s.py
================================================
from __future__ import absolute_import, division, print_function
import torch
import torch.nn as nn
from .flownet_modules import conv, deconv
from .flownet_modules import concatenate_as, upsample2d_as
from .flownet_modules import initialize_msra
class FlowNetS(nn.Module):
    """FlowNetS encoder/decoder operating on a 6-channel stacked image pair.

    Training mode returns the five-scale flow pyramid (finest first);
    eval mode returns only the finest prediction.
    """

    def __init__(self, args):
        super(FlowNetS, self).__init__()

        def conv_relu(cin, cout, k, s):
            # Convolution with half-kernel padding followed by a nonlinearity.
            return conv(cin, cout, kernel_size=k, stride=s, pad=k // 2,
                        nonlinear=True, bias=True)

        # Contracting part.
        self._conv1 = conv_relu(6, 64, 7, 2)
        self._conv2 = conv_relu(64, 128, 5, 2)
        self._conv3 = conv_relu(128, 256, 5, 2)
        self._conv3_1 = conv_relu(256, 256, 3, 1)
        self._conv4 = conv_relu(256, 512, 3, 2)
        self._conv4_1 = conv_relu(512, 512, 3, 1)
        self._conv5 = conv_relu(512, 512, 3, 2)
        self._conv5_1 = conv_relu(512, 512, 3, 1)
        self._conv6 = conv_relu(512, 1024, 3, 2)
        self._conv6_1 = conv_relu(1024, 1024, 3, 1)

        def up_feat(cin, cout):
            # Learned 2x feature upsampling with nonlinearity.
            return deconv(cin, cout, kernel_size=4, stride=2, pad=1,
                          nonlinear=True, bias=False)

        self._deconv5 = up_feat(1024, 512)
        self._deconv4 = up_feat(1024 + 2, 256)
        self._deconv3 = up_feat(768 + 2, 128)
        self._deconv2 = up_feat(384 + 2, 64)

        def flow_head(cin):
            # Linear 2-channel flow prediction.
            return conv(cin, 2, kernel_size=3, stride=1, pad=1,
                        nonlinear=False, bias=True)

        self._predict_flow6 = flow_head(1024)
        self._predict_flow5 = flow_head(1024 + 2)
        self._predict_flow4 = flow_head(768 + 2)
        self._predict_flow3 = flow_head(384 + 2)
        self._predict_flow2 = flow_head(192 + 2)

        def up_flow():
            # Learned 2x upsampling of a 2-channel flow field.
            return deconv(2, 2, kernel_size=4, stride=2, pad=1,
                          nonlinear=False, bias=False)

        self._upsample_flow6_to_5 = up_flow()
        self._upsample_flow5_to_4 = up_flow()
        self._upsample_flow4_to_3 = up_flow()
        self._upsample_flow3_to_2 = up_flow()

        initialize_msra(self.modules())

    def forward(self, inputs):
        # Encoder.
        x1 = self._conv1(inputs)
        x2 = self._conv2(x1)
        x3 = self._conv3_1(self._conv3(x2))
        x4 = self._conv4_1(self._conv4(x3))
        x5 = self._conv5_1(self._conv5(x4))
        x6 = self._conv6_1(self._conv6(x5))

        # Decoder: predict flow, upsample it, fuse with skip features.
        flow6 = self._predict_flow6(x6)
        flow6_up = self._upsample_flow6_to_5(flow6)
        deconv5 = self._deconv5(x6)
        cat5 = concatenate_as((x5, deconv5, flow6_up), x5, dim=1)

        flow5 = self._predict_flow5(cat5)
        flow5_up = self._upsample_flow5_to_4(flow5)
        deconv4 = self._deconv4(cat5)
        cat4 = concatenate_as((x4, deconv4, flow5_up), x4, dim=1)

        flow4 = self._predict_flow4(cat4)
        flow4_up = self._upsample_flow4_to_3(flow4)
        deconv3 = self._deconv3(cat4)
        cat3 = concatenate_as((x3, deconv3, flow4_up), x3, dim=1)

        flow3 = self._predict_flow3(cat3)
        flow3_up = self._upsample_flow3_to_2(flow3)
        deconv2 = self._deconv2(cat3)
        cat2 = concatenate_as((x2, deconv2, flow3_up), x2, dim=1)

        flow2 = self._predict_flow2(cat2)

        if self.training:
            return flow2, flow3, flow4, flow5, flow6
        return flow2
class FlowNet1S(nn.Module):
    """Wraps FlowNetS: stacks the image pair along channels and, at eval
    time, rescales/upsamples the finest prediction to input resolution."""

    def __init__(self, args, div_flow=0.05):
        super(FlowNet1S, self).__init__()
        self._flownets = FlowNetS(args)
        self._div_flow = div_flow

    def forward(self, input_dict):
        im1 = input_dict['input1']
        pair = torch.cat((im1, input_dict['input2']), dim=1)

        if self.training:
            # Training: expose the full multi-scale pyramid for the loss.
            flows = self._flownets(pair)
            keys = ('flow2', 'flow3', 'flow4', 'flow5', 'flow6')
            return dict(zip(keys, flows))

        # Eval: upsample the finest flow to the input size and undo the
        # training-time flow scaling.
        flow2 = self._flownets(pair)
        return {'flow1': (1.0 / self._div_flow) * upsample2d_as(flow2, im1, mode="bilinear")}
================================================
FILE: models/flownet1s_irr.py
================================================
from __future__ import absolute_import, division, print_function
import torch
import torch.nn as nn
from .flownet_modules import conv, deconv
from .flownet_modules import concatenate_as, upsample2d_as
from .flownet_modules import initialize_msra
from .flownet_modules import WarpingLayer
class FlowNetS(nn.Module):
    """FlowNetS decoder variant for iterative refinement.

    Consumes pre-computed level-2/level-3 features (the shared encoder lives
    in the outer FlowNet1S) and returns the five-scale flow pyramid.
    """

    def __init__(self, args):
        super(FlowNetS, self).__init__()

        def conv_relu(cin, cout, k, s):
            # Convolution with half-kernel padding followed by a nonlinearity.
            return conv(cin, cout, kernel_size=k, stride=s, pad=k // 2,
                        nonlinear=True, bias=True)

        self._conv3_1 = conv_relu(256, 256, 3, 1)
        self._conv4 = conv_relu(256, 512, 3, 2)
        self._conv4_1 = conv_relu(512, 512, 3, 1)
        self._conv5 = conv_relu(512, 512, 3, 2)
        self._conv5_1 = conv_relu(512, 512, 3, 1)
        self._conv6 = conv_relu(512, 1024, 3, 2)
        self._conv6_1 = conv_relu(1024, 1024, 3, 1)

        def up_feat(cin, cout):
            # Learned 2x feature upsampling with nonlinearity.
            return deconv(cin, cout, kernel_size=4, stride=2, pad=1,
                          nonlinear=True, bias=False)

        self._deconv5 = up_feat(1024, 512)
        self._deconv4 = up_feat(1024 + 2, 256)
        self._deconv3 = up_feat(768 + 2, 128)
        self._deconv2 = up_feat(384 + 2, 64)

        def flow_head(cin):
            # Linear 2-channel flow prediction.
            return conv(cin, 2, kernel_size=3, stride=1, pad=1,
                        nonlinear=False, bias=True)

        self._predict_flow6 = flow_head(1024)
        self._predict_flow5 = flow_head(1024 + 2)
        self._predict_flow4 = flow_head(768 + 2)
        self._predict_flow3 = flow_head(384 + 2)
        self._predict_flow2 = flow_head(128 + 2)

        def up_flow():
            # Learned 2x upsampling of a 2-channel flow field.
            return deconv(2, 2, kernel_size=4, stride=2, pad=1,
                          nonlinear=False, bias=False)

        self._upsample_flow6_to_5 = up_flow()
        self._upsample_flow5_to_4 = up_flow()
        self._upsample_flow4_to_3 = up_flow()
        self._upsample_flow3_to_2 = up_flow()

    def forward(self, conv2_im1, conv3_im1, conv3_im2):
        # Encoder tail on the fused level-3 features of both images.
        x3 = self._conv3_1(torch.cat((conv3_im1, conv3_im2), dim=1))
        x4 = self._conv4_1(self._conv4(x3))
        x5 = self._conv5_1(self._conv5(x4))
        x6 = self._conv6_1(self._conv6(x5))

        # Decoder: predict flow, upsample it, fuse with skip features.
        flow6 = self._predict_flow6(x6)
        cat5 = concatenate_as(
            (x5, self._deconv5(x6), self._upsample_flow6_to_5(flow6)), x5, dim=1)
        flow5 = self._predict_flow5(cat5)
        cat4 = concatenate_as(
            (x4, self._deconv4(cat5), self._upsample_flow5_to_4(flow5)), x4, dim=1)
        flow4 = self._predict_flow4(cat4)
        cat3 = concatenate_as(
            (x3, self._deconv3(cat4), self._upsample_flow4_to_3(flow4)), x3, dim=1)
        flow3 = self._predict_flow3(cat3)
        cat2 = concatenate_as(
            (conv2_im1, self._deconv2(cat3), self._upsample_flow3_to_2(flow3)), conv2_im1, dim=1)
        flow2 = self._predict_flow2(cat2)

        return flow2, flow3, flow4, flow5, flow6
class FlowNet1S(nn.Module):
    """Iterative-residual FlowNet1S.

    A shared three-level encoder feeds a shared FlowNetS decoder that is
    applied `args.num_iters` times; after each pass, image 2's level-3
    features are warped by the current (upsampled) flow, and the next pass
    predicts a residual added to the running estimate.
    """
    def __init__(self, args, div_flow=0.05):
        super(FlowNet1S, self).__init__()
        self._flownets = FlowNetS(args)
        self._warping_layer = WarpingLayer()
        # Predictions are made at div_flow scale; eval output divides it out.
        self._div_flow = div_flow
        self._num_iters = args.num_iters

        def make_conv(in_planes, out_planes, kernel_size, stride):
            pad = kernel_size // 2
            return conv(in_planes, out_planes, kernel_size=kernel_size,
                        stride=stride, pad=pad, nonlinear=True, bias=True)

        # Shared siamese encoder (applied to both images).
        self._conv1 = make_conv(3, 32, kernel_size=7, stride=2)
        self._conv2 = make_conv(32, 64, kernel_size=5, stride=2)
        self._conv3 = make_conv(64, 128, kernel_size=5, stride=2)

        initialize_msra(self.modules())

    def forward(self, input_dict):
        im1 = input_dict['input1']
        im2 = input_dict['input2']

        # Encode both images with the shared encoder.
        conv1_im1 = self._conv1(im1)
        conv2_im1 = self._conv2(conv1_im1)
        conv3_im1 = self._conv3(conv2_im1)

        conv1_im2 = self._conv1(im2)
        conv2_im2 = self._conv2(conv1_im2)
        conv3_im2_orig = self._conv3(conv2_im2)
        conv3_im2 = conv3_im2_orig  # first iteration sees unwarped features

        output_dict = {}
        output_dict['flow2'] = []
        output_dict['flow3'] = []
        output_dict['flow4'] = []
        output_dict['flow5'] = []
        output_dict['flow6'] = []

        _, _, height_im, width_im = im1.size()

        # Iterative residual refinement.
        for ii in range(0, self._num_iters):
            flow2, flow3, flow4, flow5, flow6 = self._flownets(conv2_im1, conv3_im1, conv3_im2)

            if ii == 0:
                output_dict['flow2'].append(flow2)
                output_dict['flow3'].append(flow3)
                output_dict['flow4'].append(flow4)
                output_dict['flow5'].append(flow5)
                output_dict['flow6'].append(flow6)
            else:
                # Residual update: add this pass's output to the previous estimate.
                output_dict['flow2'].append(flow2 + output_dict['flow2'][ii - 1])
                output_dict['flow3'].append(flow3 + output_dict['flow3'][ii - 1])
                output_dict['flow4'].append(flow4 + output_dict['flow4'][ii - 1])
                output_dict['flow5'].append(flow5 + output_dict['flow5'][ii - 1])
                output_dict['flow6'].append(flow6 + output_dict['flow6'][ii - 1])

            if ii < (self._num_iters - 1):
                # Warp the ORIGINAL image-2 features (not already-warped ones)
                # by the current finest accumulated flow.
                up_flow = upsample2d_as(output_dict['flow2'][ii], conv3_im2_orig, mode="bilinear")
                conv3_im2 = self._warping_layer(conv3_im2_orig, up_flow, height_im, width_im, self._div_flow)

        if self.training:
            return output_dict
        else:
            # Eval: upsample the final estimate to input size and undo scaling.
            output_dict_eval = {}
            up_flow_final = upsample2d_as(output_dict['flow2'][self._num_iters - 1], im1, mode="bilinear")
            output_dict_eval['flow1'] = (1.0 / self._div_flow) * up_flow_final
            return output_dict_eval
================================================
FILE: models/flownet1s_irr_bi.py
================================================
from __future__ import absolute_import, division, print_function
import torch
import torch.nn as nn
from .flownet_modules import conv, deconv
from .flownet_modules import concatenate_as, upsample2d_as
from .flownet_modules import initialize_msra
from .flownet_modules import WarpingLayer
class FlowNetS(nn.Module):
    """FlowNetS decoder shared by both directions of the bidirectional model.

    Consumes pre-computed level-2/level-3 features (the shared encoder lives
    in the outer FlowNet1S) and returns the five-scale flow pyramid.
    """

    def __init__(self, args):
        super(FlowNetS, self).__init__()

        def conv_relu(cin, cout, k, s):
            # Convolution with half-kernel padding followed by a nonlinearity.
            return conv(cin, cout, kernel_size=k, stride=s, pad=k // 2,
                        nonlinear=True, bias=True)

        self._conv3_1 = conv_relu(256, 256, 3, 1)
        self._conv4 = conv_relu(256, 512, 3, 2)
        self._conv4_1 = conv_relu(512, 512, 3, 1)
        self._conv5 = conv_relu(512, 512, 3, 2)
        self._conv5_1 = conv_relu(512, 512, 3, 1)
        self._conv6 = conv_relu(512, 1024, 3, 2)
        self._conv6_1 = conv_relu(1024, 1024, 3, 1)

        def up_feat(cin, cout):
            # Learned 2x feature upsampling with nonlinearity.
            return deconv(cin, cout, kernel_size=4, stride=2, pad=1,
                          nonlinear=True, bias=False)

        self._deconv5 = up_feat(1024, 512)
        self._deconv4 = up_feat(1024 + 2, 256)
        self._deconv3 = up_feat(768 + 2, 128)
        self._deconv2 = up_feat(384 + 2, 64)

        def flow_head(cin):
            # Linear 2-channel flow prediction.
            return conv(cin, 2, kernel_size=3, stride=1, pad=1,
                        nonlinear=False, bias=True)

        self._predict_flow6 = flow_head(1024)
        self._predict_flow5 = flow_head(1024 + 2)
        self._predict_flow4 = flow_head(768 + 2)
        self._predict_flow3 = flow_head(384 + 2)
        self._predict_flow2 = flow_head(128 + 2)

        def up_flow():
            # Learned 2x upsampling of a 2-channel flow field.
            return deconv(2, 2, kernel_size=4, stride=2, pad=1,
                          nonlinear=False, bias=False)

        self._upsample_flow6_to_5 = up_flow()
        self._upsample_flow5_to_4 = up_flow()
        self._upsample_flow4_to_3 = up_flow()
        self._upsample_flow3_to_2 = up_flow()

    def forward(self, conv2_im1, conv3_im1, conv3_im2):
        # Encoder tail on the fused level-3 features of both images.
        x3 = self._conv3_1(torch.cat((conv3_im1, conv3_im2), dim=1))
        x4 = self._conv4_1(self._conv4(x3))
        x5 = self._conv5_1(self._conv5(x4))
        x6 = self._conv6_1(self._conv6(x5))

        # Flow Decoder: predict flow, upsample it, fuse with skip features.
        flow6 = self._predict_flow6(x6)
        cat5 = concatenate_as(
            (x5, self._deconv5(x6), self._upsample_flow6_to_5(flow6)), x5, dim=1)
        flow5 = self._predict_flow5(cat5)
        cat4 = concatenate_as(
            (x4, self._deconv4(cat5), self._upsample_flow5_to_4(flow5)), x4, dim=1)
        flow4 = self._predict_flow4(cat4)
        cat3 = concatenate_as(
            (x3, self._deconv3(cat4), self._upsample_flow4_to_3(flow4)), x3, dim=1)
        flow3 = self._predict_flow3(cat3)
        cat2 = concatenate_as(
            (conv2_im1, self._deconv2(cat3), self._upsample_flow3_to_2(flow3)), conv2_im1, dim=1)
        flow2 = self._predict_flow2(cat2)

        return flow2, flow3, flow4, flow5, flow6
class FlowNet1S(nn.Module):
    """Bidirectional, iterative-residual FlowNetS wrapper.

    A shared 3-level encoder is applied to both frames. ``FlowNetS`` is then
    run ``args.num_iters`` times in each direction; every iteration predicts a
    residual that is added to the accumulated flow, while the level-3 features
    of the *other* frame are re-warped by the current flow estimate.
    """

    def __init__(self, args, div_flow=0.05):
        """``args`` must provide ``num_iters``; ``div_flow`` is the flow scale
        used during training (undone at eval time)."""
        super(FlowNet1S, self).__init__()
        self._flownets = FlowNetS(args)
        self._warping_layer = WarpingLayer()
        self._div_flow = div_flow
        self._num_iters = args.num_iters

        def make_conv(in_planes, out_planes, kernel_size, stride):
            # Conv + LeakyReLU with "same"-style padding.
            pad = kernel_size // 2
            return conv(in_planes, out_planes, kernel_size=kernel_size,
                        stride=stride, pad=pad, nonlinear=True, bias=True)

        self._conv1 = make_conv(  3,  32, kernel_size=7, stride=2)
        self._conv2 = make_conv( 32,  64, kernel_size=5, stride=2)
        self._conv3 = make_conv( 64, 128, kernel_size=5, stride=2)
        initialize_msra(self.modules())

    def forward(self, input_dict):
        """Training mode returns per-iteration [forward, backward] flow
        pyramids; eval mode returns the final full-resolution forward flow
        under key 'flow1'."""
        im1 = input_dict['input1']
        im2 = input_dict['input2']

        # Shared encoder on both frames.
        conv1_im1 = self._conv1(im1)
        conv2_im1 = self._conv2(conv1_im1)
        conv3_im1 = self._conv3(conv2_im1)
        conv1_im2 = self._conv1(im2)
        conv2_im2 = self._conv2(conv1_im2)
        conv3_im2 = self._conv3(conv2_im2)

        # Warped level-3 features start unwarped (zero-flow initialization).
        # Fix: dropped the dead conv2_im1_wp / conv2_im2_wp locals the original
        # assigned but never used — only the level-3 features are ever warped.
        conv3_im1_wp = conv3_im1
        conv3_im2_wp = conv3_im2

        out_dict = {}
        out_dict['flow2'] = []
        out_dict['flow3'] = []
        out_dict['flow4'] = []
        out_dict['flow5'] = []
        out_dict['flow6'] = []
        _, _, height_im, width_im = im1.size()

        # for iterative
        for ii in range(0, self._num_iters):
            flo2_f, flo3_f, flo4_f, flo5_f, flo6_f = self._flownets(conv2_im1, conv3_im1, conv3_im2_wp)
            flo2_b, flo3_b, flo4_b, flo5_b, flo6_b = self._flownets(conv2_im2, conv3_im2, conv3_im1_wp)
            if ii == 0:
                out_dict['flow2'].append([flo2_f, flo2_b])
                out_dict['flow3'].append([flo3_f, flo3_b])
                out_dict['flow4'].append([flo4_f, flo4_b])
                out_dict['flow5'].append([flo5_f, flo5_b])
                out_dict['flow6'].append([flo6_f, flo6_b])
            else:
                # Residual update: accumulate on top of the previous iteration.
                out_dict['flow2'].append([flo2_f + out_dict['flow2'][ii - 1][0], flo2_b + out_dict['flow2'][ii - 1][1]])
                out_dict['flow3'].append([flo3_f + out_dict['flow3'][ii - 1][0], flo3_b + out_dict['flow3'][ii - 1][1]])
                out_dict['flow4'].append([flo4_f + out_dict['flow4'][ii - 1][0], flo4_b + out_dict['flow4'][ii - 1][1]])
                out_dict['flow5'].append([flo5_f + out_dict['flow5'][ii - 1][0], flo5_b + out_dict['flow5'][ii - 1][1]])
                out_dict['flow6'].append([flo6_f + out_dict['flow6'][ii - 1][0], flo6_b + out_dict['flow6'][ii - 1][1]])
            if ii < (self._num_iters - 1):
                # Re-warp each frame's level-3 features with the newest flow.
                up_flow_f_c3 = upsample2d_as(out_dict['flow2'][ii][0], conv3_im2, mode="bilinear")
                up_flow_b_c3 = upsample2d_as(out_dict['flow2'][ii][1], conv3_im1, mode="bilinear")
                conv3_im2_wp = self._warping_layer(conv3_im2, up_flow_f_c3, height_im, width_im, self._div_flow)
                conv3_im1_wp = self._warping_layer(conv3_im1, up_flow_b_c3, height_im, width_im, self._div_flow)

        if self.training:
            return out_dict
        else:
            out_dict_eval = {}
            up_flow_final = upsample2d_as(out_dict['flow2'][self._num_iters - 1][0], im1, mode="bilinear")
            # Undo the div_flow scaling applied during training.
            out_dict_eval['flow1'] = (1.0 / self._div_flow) * up_flow_final
            return out_dict_eval
================================================
FILE: models/flownet1s_irr_occ.py
================================================
from __future__ import absolute_import, division, print_function
import torch
import torch.nn as nn
from .flownet_modules import conv, deconv
from .flownet_modules import concatenate_as, upsample2d_as
from .flownet_modules import initialize_msra
from .flownet_modules import WarpingLayer
class FlowNetS(nn.Module):
    """FlowNetS-style encoder/decoder that jointly predicts optical flow
    (2 channels) and occlusion (1 channel) maps at pyramid levels 2-6.

    ``forward`` takes encoder features computed externally (levels 2 and 3 of
    image 1, level 3 of image 2) and runs a shared contracting part plus two
    parallel expanding decoders: one for flow and one for occlusion.
    """
    def __init__(self, args):
        super(FlowNetS, self).__init__()
        def make_conv(in_planes, out_planes, kernel_size, stride):
            # Conv + LeakyReLU with "same"-style padding.
            pad = kernel_size // 2
            return conv(in_planes, out_planes, kernel_size=kernel_size,
                        stride=stride, pad=pad, nonlinear=True, bias=True)
        # Contracting part; input is the 128 + 128 channel level-3 feature concat.
        self._conv3_1 = make_conv( 256, 256, kernel_size=3, stride=1)
        self._conv4 = make_conv( 256, 512, kernel_size=3, stride=2)
        self._conv4_1 = make_conv( 512, 512, kernel_size=3, stride=1)
        self._conv5 = make_conv( 512, 512, kernel_size=3, stride=2)
        self._conv5_1 = make_conv( 512, 512, kernel_size=3, stride=1)
        self._conv6 = make_conv( 512, 1024, kernel_size=3, stride=2)
        self._conv6_1 = make_conv(1024, 1024, kernel_size=3, stride=1)
        def make_deconv(in_planes, out_planes):
            # 2x upsampling transposed conv + LeakyReLU.
            return deconv(in_planes, out_planes, kernel_size=4, stride=2, pad=1,
                          nonlinear=True, bias=False)
        # Flow decoder; the "+ 2" inputs carry the upsampled 2-channel flow.
        self._deconv5 = make_deconv(1024 , 512)
        self._deconv4 = make_deconv(1024 + 2, 256)
        self._deconv3 = make_deconv( 768 + 2, 128)
        self._deconv2 = make_deconv( 384 + 2, 64)
        # Occlusion decoder; the "+ 1" inputs carry the upsampled 1-channel occ map.
        self._deconv_occ5 = make_deconv(1024 , 512)
        self._deconv_occ4 = make_deconv(1024 + 1, 256)
        self._deconv_occ3 = make_deconv( 768 + 1, 128)
        self._deconv_occ2 = make_deconv( 384 + 1, 64)
        def make_predict(in_planes, out_planes):
            # Linear 3x3 prediction head (no activation).
            return conv(in_planes, out_planes, kernel_size=3, stride=1, pad=1,
                        nonlinear=False, bias=True)
        self._predict_flow6 = make_predict(1024 , 2)
        self._predict_flow5 = make_predict(1024 + 2, 2)
        self._predict_flow4 = make_predict( 768 + 2, 2)
        self._predict_flow3 = make_predict( 384 + 2, 2)
        self._predict_flow2 = make_predict( 128 + 2, 2)
        self._predict_occ6 = make_predict(1024 , 1)
        self._predict_occ5 = make_predict(1024 + 1, 1)
        self._predict_occ4 = make_predict( 768 + 1, 1)
        self._predict_occ3 = make_predict( 384 + 1, 1)
        self._predict_occ2 = make_predict( 128 + 1, 1)
        def make_upsample(in_planes, out_planes):
            # Learned 2x upsampling of a prediction (no activation).
            return deconv(in_planes, out_planes, kernel_size=4, stride=2, pad=1,
                          nonlinear=False, bias=False)
        self._upsample_flow6_to_5 = make_upsample(2, 2)
        self._upsample_flow5_to_4 = make_upsample(2, 2)
        self._upsample_flow4_to_3 = make_upsample(2, 2)
        self._upsample_flow3_to_2 = make_upsample(2, 2)
        self._upsample_occ6_to_5 = make_upsample(1, 1)
        self._upsample_occ5_to_4 = make_upsample(1, 1)
        self._upsample_occ4_to_3 = make_upsample(1, 1)
        self._upsample_occ3_to_2 = make_upsample(1, 1)
    def forward(self, conv2_im1, conv3_im1, conv3_im2):
        """Return (flow2..flow6, occ2..occ6) predictions, finest to coarsest."""
        # Fuse both frames' level-3 features along the channel dimension.
        conv_concat3 = torch.cat((conv3_im1, conv3_im2), dim=1)
        # Contracting part down to level 6.
        conv3_1 = self._conv3_1(conv_concat3)
        conv4_1 = self._conv4_1(self._conv4(conv3_1))
        conv5_1 = self._conv5_1(self._conv5(conv4_1))
        conv6_1 = self._conv6_1(self._conv6(conv5_1))
        # Flow Decoder: per level, predict, upsample, and concatenate with the
        # matching encoder features for the next (finer) level.
        predict_flow6 = self._predict_flow6(conv6_1)
        upsampled_flow6_to_5 = self._upsample_flow6_to_5(predict_flow6)
        deconv5 = self._deconv5(conv6_1)
        concat5 = concatenate_as((conv5_1, deconv5, upsampled_flow6_to_5), conv5_1, dim=1)
        predict_flow5 = self._predict_flow5(concat5)
        upsampled_flow5_to_4 = self._upsample_flow5_to_4(predict_flow5)
        deconv4 = self._deconv4(concat5)
        concat4 = concatenate_as((conv4_1, deconv4, upsampled_flow5_to_4), conv4_1, dim=1)
        predict_flow4 = self._predict_flow4(concat4)
        upsampled_flow4_to_3 = self._upsample_flow4_to_3(predict_flow4)
        deconv3 = self._deconv3(concat4)
        concat3 = concatenate_as((conv3_1, deconv3, upsampled_flow4_to_3), conv3_1, dim=1)
        predict_flow3 = self._predict_flow3(concat3)
        upsampled_flow3_to_2 = self._upsample_flow3_to_2(predict_flow3)
        deconv2 = self._deconv2(concat3)
        concat2 = concatenate_as((conv2_im1, deconv2, upsampled_flow3_to_2), conv2_im1, dim=1)
        predict_flow2 = self._predict_flow2(concat2)
        # Occ Decoder: mirrors the flow decoder with 1-channel predictions.
        predict_occ6 = self._predict_occ6(conv6_1)
        upsampled_occ6_to_5 = self._upsample_occ6_to_5(predict_occ6)
        deconv_occ5 = self._deconv_occ5(conv6_1)
        concat_occ5 = concatenate_as((conv5_1, deconv_occ5, upsampled_occ6_to_5), conv5_1, dim=1)
        predict_occ5 = self._predict_occ5(concat_occ5)
        upsampled_occ5_to_4 = self._upsample_occ5_to_4(predict_occ5)
        deconv_occ4 = self._deconv_occ4(concat_occ5)
        concat_occ4 = concatenate_as((conv4_1, deconv_occ4, upsampled_occ5_to_4), conv4_1, dim=1)
        predict_occ4 = self._predict_occ4(concat_occ4)
        upsampled_occ4_to_3 = self._upsample_occ4_to_3(predict_occ4)
        deconv_occ3 = self._deconv_occ3(concat_occ4)
        concat_occ3 = concatenate_as((conv3_1, deconv_occ3, upsampled_occ4_to_3), conv3_1, dim=1)
        predict_occ3 = self._predict_occ3(concat_occ3)
        upsampled_occ3_to_2 = self._upsample_occ3_to_2(predict_occ3)
        deconv_occ2 = self._deconv_occ2(concat_occ3)
        concat_occ2 = concatenate_as((conv2_im1, deconv_occ2, upsampled_occ3_to_2), conv2_im1, dim=1)
        predict_occ2 = self._predict_occ2(concat_occ2)
        return predict_flow2, predict_flow3, predict_flow4, predict_flow5, predict_flow6, predict_occ2, predict_occ3, predict_occ4, predict_occ5, predict_occ6
class FlowNet1S(nn.Module):
    """Iterative-residual FlowNetS predicting optical flow and occlusion.

    A shared 3-level encoder feeds ``FlowNetS``, which is applied
    ``args.num_iters`` times; each pass predicts residuals that are added to
    the accumulated predictions, and the level-3 features of image 2 are
    re-warped by the current flow before the next pass.
    """

    # Output keys, in the exact order FlowNetS returns its predictions.
    _PRED_KEYS = ('flow2', 'flow3', 'flow4', 'flow5', 'flow6',
                  'occ2', 'occ3', 'occ4', 'occ5', 'occ6')

    def __init__(self, args, div_flow=0.05):
        """``args`` must provide ``num_iters``; ``div_flow`` scales flow targets."""
        super(FlowNet1S, self).__init__()
        self._flownets = FlowNetS(args)
        self._warping_layer = WarpingLayer()
        self._div_flow = div_flow
        self._num_iters = args.num_iters

        def _encoder_conv(ch_in, ch_out, kernel_size, stride):
            # Conv + LeakyReLU with "same"-style padding.
            return conv(ch_in, ch_out, kernel_size=kernel_size, stride=stride,
                        pad=kernel_size // 2, nonlinear=True, bias=True)

        self._conv1 = _encoder_conv(3, 32, kernel_size=7, stride=2)
        self._conv2 = _encoder_conv(32, 64, kernel_size=5, stride=2)
        self._conv3 = _encoder_conv(64, 128, kernel_size=5, stride=2)
        initialize_msra(self.modules())

    def forward(self, input_dict):
        """Training: per-iteration flow/occ pyramids. Eval: final
        full-resolution 'flow1' and 'occ1'."""
        im1 = input_dict['input1']
        im2 = input_dict['input2']

        # Shared encoder on both frames.
        conv1_im1 = self._conv1(im1)
        conv2_im1 = self._conv2(conv1_im1)
        conv3_im1 = self._conv3(conv2_im1)
        conv1_im2 = self._conv1(im2)
        conv2_im2 = self._conv2(conv1_im2)
        conv3_im2 = self._conv3(conv2_im2)
        conv3_im2_wp = conv3_im2  # starts unwarped (zero-flow init)

        output_dict = {key: [] for key in self._PRED_KEYS}
        _, _, height_im, width_im = im1.size()

        for ii in range(self._num_iters):
            predictions = self._flownets(conv2_im1, conv3_im1, conv3_im2_wp)
            for key, pred in zip(self._PRED_KEYS, predictions):
                if ii == 0:
                    output_dict[key].append(pred)
                else:
                    # Residual update on top of the previous iteration.
                    output_dict[key].append(pred + output_dict[key][ii - 1])
            if ii < self._num_iters - 1:
                # Re-warp image-2 level-3 features with the newest flow.
                up_flow = upsample2d_as(output_dict['flow2'][ii], conv3_im2, mode="bilinear")
                conv3_im2_wp = self._warping_layer(conv3_im2, up_flow,
                                                   height_im, width_im, self._div_flow)

        if self.training:
            return output_dict

        output_dict_eval = {}
        up_flow_final = upsample2d_as(output_dict['flow2'][self._num_iters - 1], im1, mode="bilinear")
        up_occ_final = upsample2d_as(output_dict['occ2'][self._num_iters - 1], im1, mode="bilinear")
        output_dict_eval['flow1'] = (1.0 / self._div_flow) * up_flow_final
        output_dict_eval['occ1'] = up_occ_final
        return output_dict_eval
================================================
FILE: models/flownet1s_irr_occ_bi.py
================================================
from __future__ import absolute_import, division, print_function
import torch
import torch.nn as nn
from .flownet_modules import conv, deconv
from .flownet_modules import concatenate_as, upsample2d_as
from .flownet_modules import initialize_msra
from .flownet_modules import WarpingLayer
class FlowNetS(nn.Module):
    """FlowNetS-style encoder/decoder predicting optical flow (2 channels)
    and occlusion (1 channel) maps at pyramid levels 2-6.

    ``forward`` consumes externally computed encoder features (levels 2 and 3
    of image 1, level 3 of image 2) and runs a shared contracting part plus
    two parallel expanding decoders: one for flow and one for occlusion.
    """
    def __init__(self, args):
        super(FlowNetS, self).__init__()
        def make_conv(in_planes, out_planes, kernel_size, stride):
            # Conv + LeakyReLU with "same"-style padding.
            pad = kernel_size // 2
            return conv(in_planes, out_planes, kernel_size=kernel_size,
                        stride=stride, pad=pad, nonlinear=True, bias=True)
        # Contracting part; input is the 128 + 128 channel level-3 feature concat.
        self._conv3_1 = make_conv( 256, 256, kernel_size=3, stride=1)
        self._conv4 = make_conv( 256, 512, kernel_size=3, stride=2)
        self._conv4_1 = make_conv( 512, 512, kernel_size=3, stride=1)
        self._conv5 = make_conv( 512, 512, kernel_size=3, stride=2)
        self._conv5_1 = make_conv( 512, 512, kernel_size=3, stride=1)
        self._conv6 = make_conv( 512, 1024, kernel_size=3, stride=2)
        self._conv6_1 = make_conv(1024, 1024, kernel_size=3, stride=1)
        def make_deconv(in_planes, out_planes):
            # 2x upsampling transposed conv + LeakyReLU.
            return deconv(in_planes, out_planes, kernel_size=4, stride=2, pad=1,
                          nonlinear=True, bias=False)
        # Flow decoder; the "+ 2" inputs carry the upsampled 2-channel flow.
        self._deconv5 = make_deconv(1024 , 512)
        self._deconv4 = make_deconv(1024 + 2, 256)
        self._deconv3 = make_deconv( 768 + 2, 128)
        self._deconv2 = make_deconv( 384 + 2, 64)
        # Occlusion decoder; the "+ 1" inputs carry the upsampled 1-channel occ map.
        self._deconv_occ5 = make_deconv(1024 , 512)
        self._deconv_occ4 = make_deconv(1024 + 1, 256)
        self._deconv_occ3 = make_deconv( 768 + 1, 128)
        self._deconv_occ2 = make_deconv( 384 + 1, 64)
        def make_predict(in_planes, out_planes):
            # Linear 3x3 prediction head (no activation).
            return conv(in_planes, out_planes, kernel_size=3, stride=1, pad=1,
                        nonlinear=False, bias=True)
        self._predict_flow6 = make_predict(1024 , 2)
        self._predict_flow5 = make_predict(1024 + 2, 2)
        self._predict_flow4 = make_predict( 768 + 2, 2)
        self._predict_flow3 = make_predict( 384 + 2, 2)
        self._predict_flow2 = make_predict( 128 + 2, 2)
        self._predict_occ6 = make_predict(1024 , 1)
        self._predict_occ5 = make_predict(1024 + 1, 1)
        self._predict_occ4 = make_predict( 768 + 1, 1)
        self._predict_occ3 = make_predict( 384 + 1, 1)
        self._predict_occ2 = make_predict( 128 + 1, 1)
        def make_upsample(in_planes, out_planes):
            # Learned 2x upsampling of a prediction (no activation).
            return deconv(in_planes, out_planes, kernel_size=4, stride=2, pad=1,
                          nonlinear=False, bias=False)
        self._upsample_flow6_to_5 = make_upsample(2, 2)
        self._upsample_flow5_to_4 = make_upsample(2, 2)
        self._upsample_flow4_to_3 = make_upsample(2, 2)
        self._upsample_flow3_to_2 = make_upsample(2, 2)
        self._upsample_occ6_to_5 = make_upsample(1, 1)
        self._upsample_occ5_to_4 = make_upsample(1, 1)
        self._upsample_occ4_to_3 = make_upsample(1, 1)
        self._upsample_occ3_to_2 = make_upsample(1, 1)
    def forward(self, conv2_im1, conv3_im1, conv3_im2):
        """Return (flow2..flow6, occ2..occ6) predictions, finest to coarsest."""
        # Fuse both frames' level-3 features along the channel dimension.
        conv_concat3 = torch.cat((conv3_im1, conv3_im2), dim=1)
        # Contracting part down to level 6.
        conv3_1 = self._conv3_1(conv_concat3)
        conv4_1 = self._conv4_1(self._conv4(conv3_1))
        conv5_1 = self._conv5_1(self._conv5(conv4_1))
        conv6_1 = self._conv6_1(self._conv6(conv5_1))
        # Flow Decoder: per level, predict, upsample, and concatenate with the
        # matching encoder features for the next (finer) level.
        predict_flow6 = self._predict_flow6(conv6_1)
        upsampled_flow6_to_5 = self._upsample_flow6_to_5(predict_flow6)
        deconv5 = self._deconv5(conv6_1)
        concat5 = concatenate_as((conv5_1, deconv5, upsampled_flow6_to_5), conv5_1, dim=1)
        predict_flow5 = self._predict_flow5(concat5)
        upsampled_flow5_to_4 = self._upsample_flow5_to_4(predict_flow5)
        deconv4 = self._deconv4(concat5)
        concat4 = concatenate_as((conv4_1, deconv4, upsampled_flow5_to_4), conv4_1, dim=1)
        predict_flow4 = self._predict_flow4(concat4)
        upsampled_flow4_to_3 = self._upsample_flow4_to_3(predict_flow4)
        deconv3 = self._deconv3(concat4)
        concat3 = concatenate_as((conv3_1, deconv3, upsampled_flow4_to_3), conv3_1, dim=1)
        predict_flow3 = self._predict_flow3(concat3)
        upsampled_flow3_to_2 = self._upsample_flow3_to_2(predict_flow3)
        deconv2 = self._deconv2(concat3)
        concat2 = concatenate_as((conv2_im1, deconv2, upsampled_flow3_to_2), conv2_im1, dim=1)
        predict_flow2 = self._predict_flow2(concat2)
        # Occ Decoder: mirrors the flow decoder with 1-channel predictions.
        predict_occ6 = self._predict_occ6(conv6_1)
        upsampled_occ6_to_5 = self._upsample_occ6_to_5(predict_occ6)
        deconv_occ5 = self._deconv_occ5(conv6_1)
        concat_occ5 = concatenate_as((conv5_1, deconv_occ5, upsampled_occ6_to_5), conv5_1, dim=1)
        predict_occ5 = self._predict_occ5(concat_occ5)
        upsampled_occ5_to_4 = self._upsample_occ5_to_4(predict_occ5)
        deconv_occ4 = self._deconv_occ4(concat_occ5)
        concat_occ4 = concatenate_as((conv4_1, deconv_occ4, upsampled_occ5_to_4), conv4_1, dim=1)
        predict_occ4 = self._predict_occ4(concat_occ4)
        upsampled_occ4_to_3 = self._upsample_occ4_to_3(predict_occ4)
        deconv_occ3 = self._deconv_occ3(concat_occ4)
        concat_occ3 = concatenate_as((conv3_1, deconv_occ3, upsampled_occ4_to_3), conv3_1, dim=1)
        predict_occ3 = self._predict_occ3(concat_occ3)
        upsampled_occ3_to_2 = self._upsample_occ3_to_2(predict_occ3)
        deconv_occ2 = self._deconv_occ2(concat_occ3)
        concat_occ2 = concatenate_as((conv2_im1, deconv_occ2, upsampled_occ3_to_2), conv2_im1, dim=1)
        predict_occ2 = self._predict_occ2(concat_occ2)
        return predict_flow2, predict_flow3, predict_flow4, predict_flow5, predict_flow6, predict_occ2, predict_occ3, predict_occ4, predict_occ5, predict_occ6
class FlowNet1S(nn.Module):
    """Bidirectional iterative-residual FlowNetS with occlusion prediction.

    A shared 3-level encoder is applied to both frames. ``FlowNetS`` then runs
    ``args.num_iters`` times in each direction; every pass predicts residuals
    on top of the accumulated flow/occ estimates, and each frame's level-3
    features are re-warped by the other direction's current flow.
    """

    # Output keys, in the exact order FlowNetS returns its predictions.
    _PRED_KEYS = ('flow2', 'flow3', 'flow4', 'flow5', 'flow6',
                  'occ2', 'occ3', 'occ4', 'occ5', 'occ6')

    def __init__(self, args, div_flow=0.05):
        """``args`` must provide ``num_iters``; ``div_flow`` scales flow targets."""
        super(FlowNet1S, self).__init__()
        self._flownets = FlowNetS(args)
        self._warping_layer = WarpingLayer()
        self._div_flow = div_flow
        self._num_iters = args.num_iters

        def _encoder_conv(ch_in, ch_out, kernel_size, stride):
            # Conv + LeakyReLU with "same"-style padding.
            return conv(ch_in, ch_out, kernel_size=kernel_size, stride=stride,
                        pad=kernel_size // 2, nonlinear=True, bias=True)

        self._conv1 = _encoder_conv(3, 32, kernel_size=7, stride=2)
        self._conv2 = _encoder_conv(32, 64, kernel_size=5, stride=2)
        self._conv3 = _encoder_conv(64, 128, kernel_size=5, stride=2)
        initialize_msra(self.modules())

    def forward(self, input_dict):
        """Training: per-iteration [forward, backward] pyramids. Eval:
        full-resolution forward 'flow1' and 'occ1'."""
        im1 = input_dict['input1']
        im2 = input_dict['input2']

        # Shared encoder on both frames.
        conv1_im1 = self._conv1(im1)
        conv2_im1 = self._conv2(conv1_im1)
        conv3_im1 = self._conv3(conv2_im1)
        conv1_im2 = self._conv1(im2)
        conv2_im2 = self._conv2(conv1_im2)
        conv3_im2 = self._conv3(conv2_im2)

        # Warped level-3 features start unwarped (zero-flow initialization).
        conv3_im1_wp = conv3_im1
        conv3_im2_wp = conv3_im2

        out_dict = {key: [] for key in self._PRED_KEYS}
        _, _, height_im, width_im = im1.size()

        for ii in range(self._num_iters):
            preds_fwd = self._flownets(conv2_im1, conv3_im1, conv3_im2_wp)
            preds_bwd = self._flownets(conv2_im2, conv3_im2, conv3_im1_wp)
            for key, p_f, p_b in zip(self._PRED_KEYS, preds_fwd, preds_bwd):
                if ii == 0:
                    out_dict[key].append([p_f, p_b])
                else:
                    # Residual update on top of the previous iteration.
                    prev_f, prev_b = out_dict[key][ii - 1]
                    out_dict[key].append([p_f + prev_f, p_b + prev_b])
            if ii < self._num_iters - 1:
                # Re-warp each frame's level-3 features with the newest flow.
                up_flow_f_c3 = upsample2d_as(out_dict['flow2'][ii][0], conv3_im2, mode="bilinear")
                up_flow_b_c3 = upsample2d_as(out_dict['flow2'][ii][1], conv3_im1, mode="bilinear")
                conv3_im2_wp = self._warping_layer(conv3_im2, up_flow_f_c3,
                                                   height_im, width_im, self._div_flow)
                conv3_im1_wp = self._warping_layer(conv3_im1, up_flow_b_c3,
                                                   height_im, width_im, self._div_flow)

        if self.training:
            return out_dict

        out_dict_eval = {}
        up_flow_final = upsample2d_as(out_dict['flow2'][self._num_iters - 1][0], im1, mode="bilinear")
        up_occ_final = upsample2d_as(out_dict['occ2'][self._num_iters - 1][0], im1, mode="bilinear")
        out_dict_eval['flow1'] = (1.0 / self._div_flow) * up_flow_final
        out_dict_eval['occ1'] = up_occ_final
        return out_dict_eval
================================================
FILE: models/flownet_modules.py
================================================
from __future__ import absolute_import, division, print_function
import torch
import torch.nn as nn
import torch.nn.functional as tf
import logging
def conv(in_planes, out_planes, kernel_size, stride, pad, nonlinear, bias):
    """Build a 2D convolution, optionally followed by LeakyReLU(0.1)."""
    layer = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size,
                      stride=stride, padding=pad, bias=bias)
    if not nonlinear:
        return layer
    return nn.Sequential(layer, nn.LeakyReLU(0.1, inplace=True))
def deconv(in_planes, out_planes, kernel_size, stride, pad, nonlinear, bias):
    """Build a 2D transposed convolution, optionally followed by LeakyReLU(0.1)."""
    layer = nn.ConvTranspose2d(in_planes, out_planes, kernel_size=kernel_size,
                               stride=stride, padding=pad, bias=bias)
    if not nonlinear:
        return layer
    return nn.Sequential(layer, nn.LeakyReLU(0.1, inplace=True))
def resize2D(inputs, size_targets, mode="bilinear"):
    """Resize a NCHW tensor to ``size_targets`` = (H, W).

    Downscaling uses adaptive average pooling; upscaling uses interpolation
    (bilinear by default, align_corners=True). Returns ``inputs`` unchanged
    when the size already matches.
    """
    size_inputs = [inputs.size(2), inputs.size(3)]
    size_targets = list(size_targets)
    if size_inputs == size_targets:
        return inputs  # nothing to do
    # BUG FIX: the original wrapped whole-list comparisons in all()/any()
    # (e.g. any([size_targets < size_inputs])), which compared the two lists
    # lexicographically instead of per dimension. Pool only when every target
    # dimension is <= its input dimension; otherwise interpolate.
    if all(t <= s for t, s in zip(size_targets, size_inputs)):
        return tf.adaptive_avg_pool2d(inputs, size_targets)  # downscaling
    return tf.interpolate(inputs, size=size_targets, mode=mode, align_corners=True)
def resize2D_as(inputs, output_as, mode="bilinear"):
    """Resize ``inputs`` to the spatial (H, W) size of ``output_as``."""
    target_hw = [output_as.size(2), output_as.size(3)]
    return resize2D(inputs, target_hw, mode=mode)
def concatenate_as(tensor_list, tensor_as, dim, mode="bilinear"):
    """Resize every tensor to ``tensor_as``'s spatial size, then concatenate."""
    resized = [resize2D_as(t, tensor_as, mode=mode) for t in tensor_list]
    return torch.cat(resized, dim=dim)
def upsample2d_as(inputs, target_as, mode="bilinear"):
    """Interpolate ``inputs`` to the spatial size of ``target_as`` (align_corners=True)."""
    target_h, target_w = target_as.size(2), target_as.size(3)
    return tf.interpolate(inputs, [target_h, target_w], mode=mode, align_corners=True)
def initialize_msra(modules):
    """Apply He (MSRA) initialization to conv weights and zero their biases.

    Activations, containers, and nested FlowNet modules carry no parameters
    that need initialization and are skipped.
    """
    logging.info("Initializing MSRA")
    for layer in modules:
        if isinstance(layer, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.kaiming_normal_(layer.weight)
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)
def get_grid(x):
    """Build a (N, 2, H, W) sampling grid in [-1, 1] shaped like ``x``, on GPU."""
    n, _, h, w = x.size()
    grid_h = torch.linspace(-1.0, 1.0, w).view(1, 1, 1, w).expand(n, 1, h, w)
    grid_v = torch.linspace(-1.0, 1.0, h).view(1, 1, h, 1).expand(n, 1, h, w)
    grid = torch.cat([grid_h, grid_v], 1)
    return grid.float().requires_grad_(False).cuda()
class WarpingLayer(nn.Module):
    """Backward-warps a feature map with a flow field via grid_sample.

    The flow is given at (div_flow * image) scale; it is converted to the
    normalized [-1, 1] coordinates that grid_sample expects before sampling.
    """
    def __init__(self):
        super(WarpingLayer, self).__init__()
    def forward(self, x, flow, height_im, width_im, div_flow):
        # NOTE(review): this version normalizes by width_im / height_im, while
        # pwc_modules.WarpingLayer uses max(dim - 1, 1) — confirm which
        # convention is intended for align_corners=True sampling.
        flo_list = []
        flo_w = flow[:, 0] * 2 / width_im / div_flow
        flo_h = flow[:, 1] * 2 / height_im / div_flow
        flo_list.append(flo_w)
        flo_list.append(flo_h)
        flow_for_grid = torch.stack(flo_list).transpose(0, 1)
        # (B, 2, H, W) -> (B, H, W, 2) grid layout required by grid_sample.
        grid = torch.add(get_grid(x), flow_for_grid).transpose(1, 2).transpose(2, 3)
        x_warp = tf.grid_sample(x, grid, align_corners=True)
        return x_warp
================================================
FILE: models/irr_modules.py
================================================
from __future__ import absolute_import, division, print_function
import torch
import torch.nn as nn
import torch.nn.functional as tf
def conv(in_planes, out_planes, kernel_size=3, stride=1, dilation=1, isReLU=True):
    """Conv2d wrapped in nn.Sequential; padding keeps spatial size at stride 1.

    When ``isReLU`` is True a LeakyReLU(0.1) follows the convolution.
    """
    layers = [nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size,
                        stride=stride, dilation=dilation,
                        padding=((kernel_size - 1) * dilation) // 2, bias=True)]
    if isReLU:
        layers.append(nn.LeakyReLU(0.1, inplace=True))
    return nn.Sequential(*layers)
def upsample_factor2(inputs, target_as):
    """Nearest-neighbor 2x upsample, then bilinear-resize to ``target_as`` if
    the doubled size does not match exactly."""
    upsampled = tf.interpolate(inputs, scale_factor=2, mode="nearest")
    _, _, target_h, target_w = target_as.size()
    if upsampled.size(2) == target_h and upsampled.size(3) == target_w:
        return upsampled
    return tf.interpolate(upsampled, [target_h, target_w], mode="bilinear", align_corners=False)
class OccUpsampleNetwork(nn.Module):
    """Upsamples an occlusion map by 2x and refines it with a small residual net.

    The output is predicted as a correction added to the naively upsampled
    occlusion map.
    """

    def __init__(self, ch_in, ch_out):
        super(OccUpsampleNetwork, self).__init__()
        self.feat_dim = 32
        self.init_conv = conv(ch_in, self.feat_dim)
        self.res_convs = nn.Sequential(
            conv(self.feat_dim, self.feat_dim),
            conv(self.feat_dim, self.feat_dim, isReLU=False)
        )
        self.res_end_conv = conv(self.feat_dim, self.feat_dim)
        self.mul_const = 0.1  # damps each residual update
        self.out_convs = conv(self.feat_dim, ch_out)

    def forward(self, occ, x):
        occ = upsample_factor2(occ, x)
        features = self.init_conv(torch.cat([occ, x], dim=1))
        residual = features
        # Three damped refinement steps sharing the same residual block weights.
        for _ in range(3):
            residual = residual + self.res_convs(residual) * self.mul_const
        refined = features + self.res_end_conv(residual)
        return self.out_convs(refined) + occ
def subtract_mean(input):
    """Subtract the per-sample, per-channel spatial mean from a NCHW tensor."""
    spatial_mean = input.mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
    return input - spatial_mean
class RefineFlow(nn.Module):
    """Refines a flow field by locally re-weighting it with a learned kernel.

    A small conv net predicts, per pixel, a 3x3 weight kernel (softmax over
    the negated squared features); the refined flow is the weighted average of
    each pixel's 3x3 flow neighborhood under that kernel.
    """
    def __init__(self, ch_in):
        super(RefineFlow, self).__init__()
        self.kernel_size = 3
        self.pad_size = 1
        # Replication padding gives border pixels full 3x3 neighborhoods.
        self.pad_ftn = nn.ReplicationPad2d(self.pad_size)
        self.convs = nn.Sequential(
            conv(ch_in, 128, 3, 1, 1),
            conv(128, 128, 3, 1, 1),
            conv(128, 64, 3, 1, 1),
            conv(64, 64, 3, 1, 1),
            conv(64, 32, 3, 1, 1),
            conv(32, 32, 3, 1, 1),
            conv(32, self.kernel_size * self.kernel_size, 3, 1, 1)
        )
        self.softmax_feat = nn.Softmax(dim=1)
        # unfold_flow extracts each pixel's 3x3 neighborhood; unfold_kernel
        # flattens the per-pixel kernel into the matching layout.
        self.unfold_flow = nn.Unfold(kernel_size=(self.kernel_size, self.kernel_size))
        self.unfold_kernel = nn.Unfold(kernel_size=(1, 1))
    def forward(self, flow, diff_img, feature):
        """Return the refined 2-channel flow (same spatial size as ``flow``)."""
        b, _, h, w = flow.size()
        # Mean-free flow and photometric error magnitude serve as cues.
        flow_m = subtract_mean(flow)
        norm2_img = torch.norm(diff_img, p=2, dim=1, keepdim=True)
        feat = self.convs(torch.cat([flow_m, norm2_img, feature], dim=1))
        # Per-pixel normalized 3x3 kernel; -feat**2 peaks weights near zero response.
        feat_kernel = self.softmax_feat(-feat ** 2)
        flow_x = flow[:, 0].unsqueeze(1)
        flow_y = flow[:, 1].unsqueeze(1)
        flow_x_unfold = self.unfold_flow(self.pad_ftn(flow_x))
        flow_y_unfold = self.unfold_flow(self.pad_ftn(flow_y))
        feat_kernel_unfold = self.unfold_kernel(feat_kernel)
        # Weighted average over each 3x3 neighborhood, per flow component.
        flow_out_x = torch.sum(flow_x_unfold * feat_kernel_unfold, dim=1).unsqueeze(1).view(b, 1, h, w)
        flow_out_y = torch.sum(flow_y_unfold * feat_kernel_unfold, dim=1).unsqueeze(1).view(b, 1, h, w)
        return torch.cat([flow_out_x, flow_out_y], dim=1)
class RefineOcc(nn.Module):
    """Refines an occlusion map by locally re-weighting it with a learned kernel.

    Same scheme as RefineFlow but for a single-channel occlusion map: a conv
    net predicts a per-pixel 3x3 weight kernel and the output is the weighted
    average of each pixel's 3x3 occlusion neighborhood.
    """
    def __init__(self, ch_in):
        super(RefineOcc, self).__init__()
        self.kernel_size = 3
        self.pad_size = 1
        # Replication padding gives border pixels full 3x3 neighborhoods.
        self.pad_ftn = nn.ReplicationPad2d(self.pad_size)
        self.convs = nn.Sequential(
            conv(ch_in, 128, 3, 1, 1),
            conv(128, 128, 3, 1, 1),
            conv(128, 64, 3, 1, 1),
            conv(64, 64, 3, 1, 1),
            conv(64, 32, 3, 1, 1),
            conv(32, 32, 3, 1, 1),
            conv(32, self.kernel_size * self.kernel_size, 3, 1, 1)
        )
        self.softmax_feat = nn.Softmax(dim=1)
        self.unfold_occ = nn.Unfold(kernel_size=(self.kernel_size, self.kernel_size))
        self.unfold_kernel = nn.Unfold(kernel_size=(1, 1))
    def forward(self, occ, feat1, feat2):
        """Return the refined 1-channel occlusion map."""
        b, _, h, w = occ.size()
        feat = self.convs(torch.cat([occ, feat1, feat2], dim=1))
        # Per-pixel normalized 3x3 kernel; -feat**2 peaks weights near zero response.
        feat_kernel = self.softmax_feat(-feat ** 2)
        occ_unfold = self.unfold_occ(self.pad_ftn(occ))
        feat_kernel_unfold = self.unfold_kernel(feat_kernel)
        # Weighted average over each 3x3 neighborhood.
        occ_out = torch.sum(occ_unfold * feat_kernel_unfold, dim=1).unsqueeze(1).view(b, 1, h, w)
        return occ_out
================================================
FILE: models/pwc_modules.py
================================================
from __future__ import absolute_import, division, print_function
import torch
import torch.nn as nn
import torch.nn.functional as tf
import logging
def conv(in_planes, out_planes, kernel_size=3, stride=1, dilation=1, isReLU=True):
    """Conv2d wrapped in nn.Sequential; padding keeps spatial size at stride 1.

    When ``isReLU`` is True a LeakyReLU(0.1) follows the convolution.
    """
    layers = [nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size,
                        stride=stride, dilation=dilation,
                        padding=((kernel_size - 1) * dilation) // 2, bias=True)]
    if isReLU:
        layers.append(nn.LeakyReLU(0.1, inplace=True))
    return nn.Sequential(*layers)
def initialize_msra(modules):
    """He (MSRA) initialization for conv / transposed-conv weights; zero biases.

    LeakyReLU and Sequential containers have nothing to initialize and are
    skipped.
    """
    logging.info("Initializing MSRA")
    for layer in modules:
        if isinstance(layer, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.kaiming_normal_(layer.weight)
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)
def compute_cost_volume(feat1, feat2, param_dict):
    """
    Brute-force cost volume between two NCHW feature maps.

    Correlates ``feat1`` with ``feat2`` shifted by every displacement in
    [-max_disp, max_disp]^2 (zero padding outside the image), averaging the
    elementwise product over channels.

    only implemented for:
        kernel_size = 1
        stride1 = 1
        stride2 = 1

    Returns:
        Tensor of shape (N, (2*max_disp+1)**2, H, W).
    """
    max_disp = param_dict["max_disp"]
    _, _, height, width = feat1.size()
    num_shifts = 2 * max_disp + 1
    feat2_padded = tf.pad(feat2, (max_disp, max_disp, max_disp, max_disp), "constant", 0)

    cost_list = []
    for i in range(num_shifts):
        for j in range(num_shifts):
            # Channel-mean of the elementwise product = correlation score.
            # Use the canonical dim/keepdim kwargs rather than the numpy-style
            # axis/keepdims aliases the original relied on.
            corr = torch.mean(feat1 * feat2_padded[:, :, i:(height + i), j:(width + j)],
                              dim=1, keepdim=True)
            cost_list.append(corr)
    cost_volume = torch.cat(cost_list, dim=1)
    return cost_volume
def upsample2d_as(inputs, target_as, mode="bilinear"):
    """Resize ``inputs`` to ``target_as``'s H x W via interpolation (align_corners=True)."""
    size = [target_as.size(2), target_as.size(3)]
    return tf.interpolate(inputs, size, mode=mode, align_corners=True)
def rescale_flow(flow, div_flow, width_im, height_im, to_local=True):
    """Rescale flow between image-level and pyramid-level pixel units.

    to_local=True:  (div_flow * image)-scale flow -> pixel units of the
                    current pyramid level.
    to_local=False: level-local pixel flow -> (div_flow * image) scale.

    Returns a NEW tensor; the input ``flow`` is left untouched.
    """
    if to_local:
        u_scale = float(flow.size(3) / width_im / div_flow)
        v_scale = float(flow.size(2) / height_im / div_flow)
    else:
        u_scale = float(width_im * div_flow / flow.size(3))
        v_scale = float(height_im * div_flow / flow.size(2))

    u, v = flow.chunk(2, dim=1)
    # BUG FIX: chunk() returns views of `flow`, so the original in-place
    # "u *= u_scale" silently mutated the caller's tensor (and can trip
    # autograd's in-place checks). Compute out-of-place instead.
    return torch.cat([u * u_scale, v * v_scale], dim=1)
class FeatureExtractor(nn.Module):
    """Convolutional feature pyramid; returns per-level features, coarsest first.

    ``num_chs`` lists the channel count per level; each level halves the
    spatial resolution with a stride-2 conv followed by a stride-1 conv.
    """

    def __init__(self, num_chs):
        super(FeatureExtractor, self).__init__()
        self.num_chs = num_chs
        self.convs = nn.ModuleList()
        for ch_in, ch_out in zip(num_chs[:-1], num_chs[1:]):
            self.convs.append(nn.Sequential(
                conv(ch_in, ch_out, stride=2),  # halves resolution
                conv(ch_out, ch_out)
            ))

    def forward(self, x):
        pyramid = []
        for block in self.convs:
            x = block(x)
            pyramid.append(x)
        pyramid.reverse()  # coarsest (lowest-resolution) level first
        return pyramid
def get_grid(x):
    """Build a (N, 2, H, W) sampling grid in [-1, 1] shaped like ``x``, on GPU."""
    n, _, h, w = x.size()
    grid_h = torch.linspace(-1.0, 1.0, w).view(1, 1, 1, w).expand(n, 1, h, w)
    grid_v = torch.linspace(-1.0, 1.0, h).view(1, 1, h, 1).expand(n, 1, h, w)
    grid = torch.cat([grid_h, grid_v], 1)
    return grid.float().requires_grad_(False).cuda()
class WarpingLayer(nn.Module):
    """Backward-warps a feature map with a flow field via grid_sample.

    The flow is given at (div_flow * image) scale; it is converted to the
    normalized [-1, 1] coordinates grid_sample expects. Samples that fall
    outside the source image are zeroed with a validity mask.
    """
    def __init__(self):
        super(WarpingLayer, self).__init__()
    def forward(self, x, flow, height_im, width_im, div_flow):
        flo_list = []
        # 2/(dim-1) converts pixels to normalized units (align_corners=True
        # convention); the max(..., 1) guards against degenerate 1-pixel dims.
        flo_w = flow[:, 0] * 2 / max(width_im - 1, 1) / div_flow
        flo_h = flow[:, 1] * 2 / max(height_im - 1, 1) / div_flow
        flo_list.append(flo_w)
        flo_list.append(flo_h)
        flow_for_grid = torch.stack(flo_list).transpose(0, 1)
        # (B, 2, H, W) -> (B, H, W, 2) grid layout required by grid_sample.
        grid = torch.add(get_grid(x), flow_for_grid).transpose(1, 2).transpose(2, 3)
        x_warp = tf.grid_sample(x, grid, align_corners=True)
        # Warp an all-ones tensor the same way; positions that sampled any
        # out-of-bounds (zero-padded) content get mask < 1 and are zeroed.
        mask = torch.ones(x.size(), requires_grad=False).cuda()
        mask = tf.grid_sample(mask, grid, align_corners=True)
        mask = (mask >= 1.0).float()
        return x_warp * mask
class OpticalFlowEstimator(nn.Module):
    """Plain conv stack predicting a 2-channel flow plus intermediate features."""

    def __init__(self, ch_in):
        super(OpticalFlowEstimator, self).__init__()
        self.convs = nn.Sequential(
            conv(ch_in, 128),
            conv(128, 128),
            conv(128, 96),
            conv(96, 64),
            conv(64, 32)
        )
        self.conv_last = conv(32, 2, isReLU=False)  # linear prediction head

    def forward(self, x):
        feat = self.convs(x)
        flow = self.conv_last(feat)
        return feat, flow
class FlowEstimatorDense(nn.Module):
    """DenseNet-style flow estimator: every conv sees all previous features.

    Each layer's output is concatenated in front of its input, so channel
    counts grow by 128, 128, 96, 64 and 32 across the five layers.
    """

    def __init__(self, ch_in):
        super(FlowEstimatorDense, self).__init__()
        self.conv1 = conv(ch_in, 128)
        self.conv2 = conv(ch_in + 128, 128)
        self.conv3 = conv(ch_in + 256, 96)
        self.conv4 = conv(ch_in + 352, 64)
        self.conv5 = conv(ch_in + 416, 32)
        self.conv_last = conv(ch_in + 448, 2, isReLU=False)  # linear head

    def forward(self, x):
        feat = x
        # Dense connectivity: prepend each layer's output to its own input.
        for layer in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            feat = torch.cat([layer(feat), feat], dim=1)
        return feat, self.conv_last(feat)
class OcclusionEstimator(nn.Module):
    """Plain conv stack predicting a 1-channel occlusion map plus features."""

    def __init__(self, ch_in):
        super(OcclusionEstimator, self).__init__()
        self.convs = nn.Sequential(
            conv(ch_in, 128),
            conv(128, 128),
            conv(128, 96),
            conv(96, 64),
            conv(64, 32)
        )
        self.conv_last = conv(32, 1, isReLU=False)  # linear prediction head

    def forward(self, x):
        feat = self.convs(x)
        occ = self.conv_last(feat)
        return feat, occ
class OccEstimatorDense(nn.Module):
    """DenseNet-style occlusion estimator: every conv sees all previous features.

    Identical to FlowEstimatorDense except the final head emits 1 channel.
    """

    def __init__(self, ch_in):
        super(OccEstimatorDense, self).__init__()
        self.conv1 = conv(ch_in, 128)
        self.conv2 = conv(ch_in + 128, 128)
        self.conv3 = conv(ch_in + 256, 96)
        self.conv4 = conv(ch_in + 352, 64)
        self.conv5 = conv(ch_in + 416, 32)
        self.conv_last = conv(ch_in + 448, 1, isReLU=False)  # linear head

    def forward(self, x):
        feat = x
        # Dense connectivity: prepend each layer's output to its own input.
        for layer in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            feat = torch.cat([layer(feat), feat], dim=1)
        return feat, self.conv_last(feat)
class ContextNetwork(nn.Module):
    """Dilated-convolution refinement network producing a 2-channel flow residual."""

    def __init__(self, ch_in):
        super(ContextNetwork, self).__init__()
        # (out_channels, dilation) for each 3x3, stride-1 trunk layer; growing
        # dilation enlarges the receptive field without losing resolution.
        specs = [(128, 1), (128, 2), (128, 4), (96, 8), (64, 16), (32, 1)]
        layers = []
        prev = ch_in
        for out_ch, dil in specs:
            layers.append(conv(prev, out_ch, 3, 1, dil))
            prev = out_ch
        # Linear head (no ReLU) emitting the 2-channel residual.
        layers.append(conv(32, 2, isReLU=False))
        self.convs = nn.Sequential(*layers)

    def forward(self, x):
        return self.convs(x)
class OccContextNetwork(nn.Module):
    """Dilated-convolution refinement network producing a 1-channel occlusion residual."""

    def __init__(self, ch_in):
        super(OccContextNetwork, self).__init__()
        # (out_channels, dilation) per 3x3, stride-1 trunk layer.
        specs = [(128, 1), (128, 2), (128, 4), (96, 8), (64, 16), (32, 1)]
        layers = []
        prev = ch_in
        for out_ch, dil in specs:
            layers.append(conv(prev, out_ch, 3, 1, dil))
            prev = out_ch
        # Linear head (no ReLU) emitting the 1-channel residual.
        layers.append(conv(32, 1, isReLU=False))
        self.convs = nn.Sequential(*layers)

    def forward(self, x):
        return self.convs(x)
================================================
FILE: models/pwcnet.py
================================================
from __future__ import absolute_import, division, print_function
import torch
import torch.nn as nn
from .pwc_modules import upsample2d_as, initialize_msra, compute_cost_volume
from .pwc_modules import WarpingLayer, FeatureExtractor, ContextNetwork, FlowEstimatorDense
class PWCNet(nn.Module):
    """PWC-Net: coarse-to-fine optical flow via per-level warping, a cost
    volume, and a dense flow decoder, with a context network refining the
    finest output level."""

    def __init__(self, args, div_flow=0.05):
        super(PWCNet, self).__init__()
        self.args = args
        self._div_flow = div_flow
        self.search_range = 4
        self.num_chs = [3, 16, 32, 64, 96, 128, 196]
        self.output_level = 4
        self.num_levels = 7
        self.leakyRELU = nn.LeakyReLU(0.1, inplace=True)
        self.feature_pyramid_extractor = FeatureExtractor(self.num_chs)
        self.warping_layer = WarpingLayer()
        self.flow_estimators = nn.ModuleList()
        # Cost-volume width: one channel per displacement in a (2r+1)^2 window.
        self.dim_corr = (self.search_range * 2 + 1) ** 2
        for level, ch in enumerate(reversed(self.num_chs)):
            if level > self.output_level:
                break
            # Coarsest level has no upsampled flow/features to concatenate.
            extra = 0 if level == 0 else ch + 2
            self.flow_estimators.append(FlowEstimatorDense(self.dim_corr + extra))
        self.context_networks = ContextNetwork(self.dim_corr + 32 + 2 + 448 + 2)
        self.corr_params = {"pad_size": self.search_range, "kernel_size": 1,
                            "max_disp": self.search_range, "stride1": 1,
                            "stride2": 1, "corr_multiply": 1}
        initialize_msra(self.modules())

    def forward(self, input_dict):
        """Return per-level flows while training, or the final full-resolution
        flow (rescaled by 1/div_flow) at evaluation time."""
        im1 = input_dict['input1']
        im2 = input_dict['input2']
        _, _, height_im, width_im = im1.size()

        # Feature pyramids; the raw images act as the bottom (finest) level.
        pyr1 = self.feature_pyramid_extractor(im1) + [im1]
        pyr2 = self.feature_pyramid_extractor(im2) + [im2]

        flows = []

        # Zero flow at the coarsest resolution.
        b_size, _, h_top, w_top = pyr1[0].size()
        flow = torch.zeros(b_size, 2, h_top, w_top,
                           dtype=pyr1[0].dtype, device=pyr1[0].device).float()

        for level, (x1, x2) in enumerate(zip(pyr1, pyr2)):
            # Warp the second image's features by the current flow estimate.
            if level == 0:
                x2_warp = x2
            else:
                flow = upsample2d_as(flow, x1, mode="bilinear")
                x2_warp = self.warping_layer(x2, flow, height_im, width_im, self._div_flow)

            # Cost volume between first-image features and the warped second image.
            corr = self.leakyRELU(compute_cost_volume(x1, x2_warp, self.corr_params))

            # Decode the flow at this level.
            if level == 0:
                feat, flow = self.flow_estimators[level](corr)
            else:
                feat, flow = self.flow_estimators[level](torch.cat([corr, x1, flow], dim=1))

            if level == self.output_level:
                # Refine the finest prediction with the context network and stop.
                flow = flow + self.context_networks(torch.cat([feat, flow], dim=1))
                flows.append(flow)
                break
            flows.append(flow)

        output_dict = {'flow': flows}
        if self.training:
            return output_dict
        # Evaluation: single full-resolution flow, rescaled to pixel units.
        return {'flow': upsample2d_as(flow, im1, mode="bilinear") * (1.0 / self._div_flow)}
================================================
FILE: models/pwcnet_bi.py
================================================
from __future__ import absolute_import, division, print_function
import torch
import torch.nn as nn
from .pwc_modules import upsample2d_as, initialize_msra, compute_cost_volume
from .pwc_modules import WarpingLayer, FeatureExtractor, ContextNetwork, FlowEstimatorDense
class PWCNet(nn.Module):
    """Bidirectional PWC-Net: estimates forward (im1->im2) and backward
    (im2->im1) flow simultaneously, sharing the per-level flow estimators
    and the context network between both directions."""

    def __init__(self, args, div_flow=0.05):
        super(PWCNet, self).__init__()
        self.args = args
        self._div_flow = div_flow  # global flow scale; undone at evaluation time
        self.search_range = 4  # cost-volume displacement radius
        self.num_chs = [3, 16, 32, 64, 96, 128, 196]  # pyramid channel widths, fine to coarse
        self.output_level = 4  # finest pyramid level at which flow is decoded
        self.num_levels = 7
        self.leakyRELU = nn.LeakyReLU(0.1, inplace=True)
        self.feature_pyramid_extractor = FeatureExtractor(self.num_chs)
        self.warping_layer = WarpingLayer()
        self.flow_estimators = nn.ModuleList()
        # (2r+1)^2 correlation channels for search radius r.
        self.dim_corr = (self.search_range * 2 + 1) ** 2
        for l, ch in enumerate(self.num_chs[::-1]):
            if l > self.output_level:
                break
            if l == 0:
                # Coarsest level: no previous flow/features to concatenate.
                num_ch_in = self.dim_corr
            else:
                # corr channels + this level's features + 2 flow channels.
                num_ch_in = self.dim_corr + ch + 2
            layer = FlowEstimatorDense(num_ch_in)
            self.flow_estimators.append(layer)
        self.context_networks = ContextNetwork(self.dim_corr + 32 + 2 + 448 + 2)
        self.corr_params = {"pad_size": self.search_range, "kernel_size": 1, "max_disp": self.search_range, "stride1": 1, "stride2": 1, "corr_multiply": 1}
        initialize_msra(self.modules())

    def forward(self, input_dict):
        """Return per-level [forward, backward] flow pairs while training, or
        the final full-resolution forward flow at evaluation time."""
        x1_raw = input_dict['input1']
        x2_raw = input_dict['input2']
        _, _, height_im, width_im = x1_raw.size()
        # on the bottom level are original images
        x1_pyramid = self.feature_pyramid_extractor(x1_raw) + [x1_raw]
        x2_pyramid = self.feature_pyramid_extractor(x2_raw) + [x2_raw]
        # outputs
        output_dict = {}
        flows = []
        # init: zero flow in both directions at the coarsest resolution
        b_size, _, h_x1, w_x1, = x1_pyramid[0].size()
        init_dtype = x1_pyramid[0].dtype
        init_device = x1_pyramid[0].device
        flow_f = torch.zeros(b_size, 2, h_x1, w_x1, dtype=init_dtype, device=init_device).float()
        flow_b = torch.zeros(b_size, 2, h_x1, w_x1, dtype=init_dtype, device=init_device).float()
        for l, (x1, x2) in enumerate(zip(x1_pyramid, x2_pyramid)):
            # warping: each image's features are warped by the opposite-direction flow
            if l == 0:
                x2_warp = x2
                x1_warp = x1
            else:
                flow_f = upsample2d_as(flow_f, x1, mode="bilinear")
                flow_b = upsample2d_as(flow_b, x2, mode="bilinear")
                x2_warp = self.warping_layer(x2, flow_f, height_im, width_im, self._div_flow)
                x1_warp = self.warping_layer(x1, flow_b, height_im, width_im, self._div_flow)
            # correlation: cost volumes for both directions
            out_corr_f = compute_cost_volume(x1, x2_warp, self.corr_params)
            out_corr_b = compute_cost_volume(x2, x1_warp, self.corr_params)
            out_corr_relu_f = self.leakyRELU(out_corr_f)
            out_corr_relu_b = self.leakyRELU(out_corr_b)
            # flow estimator (shared weights for forward and backward)
            if l == 0:
                x_intm_f, flow_f = self.flow_estimators[l](out_corr_relu_f)
                x_intm_b, flow_b = self.flow_estimators[l](out_corr_relu_b)
            else:
                x_intm_f, flow_f = self.flow_estimators[l](torch.cat([out_corr_relu_f, x1, flow_f], dim=1))
                x_intm_b, flow_b = self.flow_estimators[l](torch.cat([out_corr_relu_b, x2, flow_b], dim=1))
            # upsampling or post-processing: context refinement only at the output level
            if l != self.output_level:
                flows.append([flow_f, flow_b])
            else:
                flow_fine_f = self.context_networks(torch.cat([x_intm_f, flow_f], dim=1))
                flow_fine_b = self.context_networks(torch.cat([x_intm_b, flow_b], dim=1))
                flow_f = flow_f + flow_fine_f
                flow_b = flow_b + flow_fine_b
                flows.append([flow_f, flow_b])
                break
        output_dict['flow'] = flows
        if self.training:
            return output_dict
        else:
            # Evaluation returns only the forward flow, rescaled to pixel units.
            output_dict_eval = {}
            out_flow = upsample2d_as(flow_f, x1_raw, mode="bilinear") * (1.0 / self._div_flow)
            output_dict_eval['flow'] = out_flow
            return output_dict_eval
================================================
FILE: models/pwcnet_irr.py
================================================
from __future__ import absolute_import, division, print_function
import torch
import torch.nn as nn
from .pwc_modules import conv, rescale_flow, upsample2d_as, initialize_msra, compute_cost_volume
from .pwc_modules import WarpingLayer, FeatureExtractor, ContextNetwork, FlowEstimatorDense
class PWCNet(nn.Module):
    """Iterative-residual PWC-Net: a single flow estimator and context network
    are shared across all pyramid levels, with 1x1 convs adapting each level's
    feature width down to the 32 channels the shared estimator expects."""

    def __init__(self, args, div_flow=0.05):
        super(PWCNet, self).__init__()
        self.args = args
        self._div_flow = div_flow
        self.search_range = 4
        self.num_chs = [3, 16, 32, 64, 96, 128, 196]
        self.output_level = 4
        self.num_levels = 7
        self.leakyRELU = nn.LeakyReLU(0.1, inplace=True)
        self.feature_pyramid_extractor = FeatureExtractor(self.num_chs)
        self.warping_layer = WarpingLayer()
        # Shared estimator input: correlation + 32 adapted feature channels + 2 flow channels.
        self.dim_corr = (self.search_range * 2 + 1) ** 2
        self.num_ch_in = self.dim_corr + 32 + 2
        self.flow_estimators = FlowEstimatorDense(self.num_ch_in)
        self.context_networks = ContextNetwork(self.num_ch_in + 448 + 2)
        # One 1x1 adapter per level (coarse to fine), mapping its width to 32.
        self.conv_1x1 = nn.ModuleList(
            [conv(n, 32, kernel_size=1, stride=1, dilation=1)
             for n in (196, 128, 96, 64, 32)])
        self.corr_params = {"pad_size": self.search_range, "kernel_size": 1,
                            "max_disp": self.search_range, "stride1": 1,
                            "stride2": 1, "corr_multiply": 1}
        initialize_msra(self.modules())

    def forward(self, input_dict):
        """Return per-level flows while training, or the final full-resolution
        flow (rescaled by 1/div_flow) at evaluation time."""
        im1 = input_dict['input1']
        im2 = input_dict['input2']
        _, _, height_im, width_im = im1.size()

        # Feature pyramids; the raw images form the finest level.
        pyr1 = self.feature_pyramid_extractor(im1) + [im1]
        pyr2 = self.feature_pyramid_extractor(im2) + [im2]

        flows = []

        # Zero flow at the coarsest resolution.
        b_size, _, h_top, w_top = pyr1[0].size()
        flow = torch.zeros(b_size, 2, h_top, w_top,
                           dtype=pyr1[0].dtype, device=pyr1[0].device).float()

        for level, (x1, x2) in enumerate(zip(pyr1, pyr2)):
            # Warp second-image features by the upsampled current flow.
            if level == 0:
                x2_warp = x2
            else:
                flow = upsample2d_as(flow, x1, mode="bilinear")
                x2_warp = self.warping_layer(x2, flow, height_im, width_im, self._div_flow)

            corr = self.leakyRELU(compute_cost_volume(x1, x2_warp, self.corr_params))

            # The shared estimator works in level-local flow units; convert in,
            # apply the residual updates, then convert back to global units.
            flow = rescale_flow(flow, self._div_flow, width_im, height_im, to_local=True)
            feat_1by1 = self.conv_1x1[level](x1)
            feat, flow_res = self.flow_estimators(torch.cat([corr, feat_1by1, flow], dim=1))
            flow = flow + flow_res
            flow = flow + self.context_networks(torch.cat([feat, flow], dim=1))
            flow = rescale_flow(flow, self._div_flow, width_im, height_im, to_local=False)
            flows.append(flow)

            if level == self.output_level:
                break

        if self.training:
            return {'flow': flows}
        # Evaluation: single full-resolution flow, rescaled to pixel units.
        return {'flow': upsample2d_as(flow, im1, mode="bilinear") * (1.0 / self._div_flow)}
================================================
FILE: models/pwcnet_irr_bi.py
================================================
from __future__ import absolute_import, division, print_function
import torch
import torch.nn as nn
from .pwc_modules import conv, rescale_flow, upsample2d_as, initialize_msra, compute_cost_volume
from .pwc_modules import WarpingLayer, FeatureExtractor, ContextNetwork, FlowEstimatorDense
class PWCNet(nn.Module):
    """Bidirectional iterative-residual PWC-Net: one shared flow estimator and
    context network are applied at every pyramid level and in both directions
    (im1->im2 and im2->im1); per-level 1x1 convs adapt feature widths to 32."""

    def __init__(self, args, div_flow=0.05):
        super(PWCNet, self).__init__()
        self.args = args
        self._div_flow = div_flow  # global flow scale; undone at evaluation time
        self.search_range = 4  # cost-volume displacement radius
        self.num_chs = [3, 16, 32, 64, 96, 128, 196]  # pyramid channel widths
        self.output_level = 4  # finest pyramid level at which flow is decoded
        self.num_levels = 7
        self.leakyRELU = nn.LeakyReLU(0.1, inplace=True)
        self.feature_pyramid_extractor = FeatureExtractor(self.num_chs)
        self.warping_layer = WarpingLayer()
        # Shared estimator input: (2r+1)^2 correlation channels + 32 adapted
        # feature channels + 2 flow channels.
        self.dim_corr = (self.search_range * 2 + 1) ** 2
        self.num_ch_in = self.dim_corr + 32 + 2
        self.flow_estimators = FlowEstimatorDense(self.num_ch_in)
        self.context_networks = ContextNetwork(self.num_ch_in + 448 + 2)
        # One 1x1 adapter per level (coarse to fine), mapping widths to 32 channels.
        self.conv_1x1 = nn.ModuleList([conv(196, 32, kernel_size=1, stride=1, dilation=1),
                                       conv(128, 32, kernel_size=1, stride=1, dilation=1),
                                       conv(96, 32, kernel_size=1, stride=1, dilation=1),
                                       conv(64, 32, kernel_size=1, stride=1, dilation=1),
                                       conv(32, 32, kernel_size=1, stride=1, dilation=1)])
        self.corr_params = {"pad_size": self.search_range, "kernel_size": 1, "max_disp": self.search_range, "stride1": 1, "stride2": 1, "corr_multiply": 1}
        initialize_msra(self.modules())

    def forward(self, input_dict):
        """Return per-level [forward, backward] flow pairs while training, or
        the final full-resolution forward flow at evaluation time."""
        x1_raw = input_dict['input1']
        x2_raw = input_dict['input2']
        _, _, height_im, width_im = x1_raw.size()
        # on the bottom level are original images
        x1_pyramid = self.feature_pyramid_extractor(x1_raw) + [x1_raw]
        x2_pyramid = self.feature_pyramid_extractor(x2_raw) + [x2_raw]
        # outputs
        output_dict = {}
        flows = []
        # init: zero flow in both directions at the coarsest resolution
        b_size, _, h_x1, w_x1, = x1_pyramid[0].size()
        init_dtype = x1_pyramid[0].dtype
        init_device = x1_pyramid[0].device
        flow_f = torch.zeros(b_size, 2, h_x1, w_x1, dtype=init_dtype, device=init_device).float()
        flow_b = torch.zeros(b_size, 2, h_x1, w_x1, dtype=init_dtype, device=init_device).float()
        for l, (x1, x2) in enumerate(zip(x1_pyramid, x2_pyramid)):
            # warping: each image's features are warped by the opposite-direction flow
            if l == 0:
                x2_warp = x2
                x1_warp = x1
            else:
                flow_f = upsample2d_as(flow_f, x1, mode="bilinear")
                flow_b = upsample2d_as(flow_b, x2, mode="bilinear")
                x2_warp = self.warping_layer(x2, flow_f, height_im, width_im, self._div_flow)
                x1_warp = self.warping_layer(x1, flow_b, height_im, width_im, self._div_flow)
            # correlation: cost volumes for both directions
            out_corr_f = compute_cost_volume(x1, x2_warp, self.corr_params)
            out_corr_b = compute_cost_volume(x2, x1_warp, self.corr_params)
            out_corr_relu_f = self.leakyRELU(out_corr_f)
            out_corr_relu_b = self.leakyRELU(out_corr_b)
            # concat and estimate flow: convert to level-local units for the
            # shared estimator, apply residual + context updates, convert back
            flow_f = rescale_flow(flow_f, self._div_flow, width_im, height_im, to_local=True)
            flow_b = rescale_flow(flow_b, self._div_flow, width_im, height_im, to_local=True)
            x1_1by1 = self.conv_1x1[l](x1)
            x2_1by1 = self.conv_1x1[l](x2)
            x_intm_f, flow_res_f = self.flow_estimators(torch.cat([out_corr_relu_f, x1_1by1, flow_f], dim=1))
            x_intm_b, flow_res_b = self.flow_estimators(torch.cat([out_corr_relu_b, x2_1by1, flow_b], dim=1))
            flow_f = flow_f + flow_res_f
            flow_b = flow_b + flow_res_b
            flow_fine_f = self.context_networks(torch.cat([x_intm_f, flow_f], dim=1))
            flow_fine_b = self.context_networks(torch.cat([x_intm_b, flow_b], dim=1))
            flow_f = flow_f + flow_fine_f
            flow_b = flow_b + flow_fine_b
            flow_f = rescale_flow(flow_f, self._div_flow, width_im, height_im, to_local=False)
            flow_b = rescale_flow(flow_b, self._div_flow, width_im, height_im, to_local=False)
            flows.append([flow_f, flow_b])
            # upsampling or post-processing: stop after the output level
            if l == self.output_level:
                break
        output_dict['flow'] = flows
        if self.training:
            return output_dict
        else:
            # Evaluation returns only the forward flow, rescaled to pixel units.
            output_dict_eval = {}
            output_dict_eval['flow'] = upsample2d_as(flow_f, x1_raw, mode="bilinear") * (1.0 / self._div_flow)
            return output_dict_eval
================================================
FILE: models/pwcnet_irr_occ.py
================================================
from __future__ import absolute_import, division, print_function
import torch
import torch.nn as nn
from .pwc_modules import conv, rescale_flow, upsample2d_as, initialize_msra, compute_cost_volume
from .pwc_modules import WarpingLayer, FeatureExtractor, FlowEstimatorDense, ContextNetwork, OccEstimatorDense, OccContextNetwork
class PWCNet(nn.Module):
    """Iterative-residual PWC-Net with joint occlusion estimation
    (unidirectional): a single shared flow estimator and a single shared
    occlusion estimator are applied at every pyramid level, with 1x1 convs
    adapting each level's feature width to 32 channels."""

    def __init__(self, args, div_flow=0.05):
        super(PWCNet, self).__init__()
        self.args = args
        self._div_flow = div_flow  # global flow scale; undone at evaluation time
        self.search_range = 4  # cost-volume displacement radius
        self.num_chs = [3, 16, 32, 64, 96, 128, 196]  # pyramid channel widths
        self.output_level = 4  # finest pyramid level at which estimates are produced
        self.num_levels = 7
        self.leakyRELU = nn.LeakyReLU(0.1, inplace=True)
        self.feature_pyramid_extractor = FeatureExtractor(self.num_chs)
        self.warping_layer = WarpingLayer()
        self.dim_corr = (self.search_range * 2 + 1) ** 2  # (2r+1)^2 correlation channels
        self.num_ch_in_flo = self.dim_corr + 32 + 2  # corr + adapted features + flow
        self.num_ch_in_occ = self.dim_corr + 32 + 1  # corr + adapted features + occ
        self.flow_estimators = FlowEstimatorDense(self.num_ch_in_flo)
        self.context_networks = ContextNetwork(self.num_ch_in_flo + 448 + 2)
        self.occ_estimators = OccEstimatorDense(self.num_ch_in_occ)
        self.occ_context_networks = OccContextNetwork(self.num_ch_in_occ + 448 + 1)
        # One 1x1 adapter per level (coarse to fine), mapping widths to 32 channels.
        self.conv_1x1 = nn.ModuleList([conv(196, 32, kernel_size=1, stride=1, dilation=1),
                                       conv(128, 32, kernel_size=1, stride=1, dilation=1),
                                       conv(96, 32, kernel_size=1, stride=1, dilation=1),
                                       conv(64, 32, kernel_size=1, stride=1, dilation=1),
                                       conv(32, 32, kernel_size=1, stride=1, dilation=1)])
        self.corr_params = {"pad_size": self.search_range, "kernel_size": 1, "max_disp": self.search_range, "stride1": 1, "stride2": 1, "corr_multiply": 1}
        initialize_msra(self.modules())

    def forward(self, input_dict):
        """Return per-level flow and occlusion predictions while training, or
        the final full-resolution flow/occlusion pair at evaluation time."""
        x1_raw = input_dict['input1']
        x2_raw = input_dict['input2']
        _, _, height_im, width_im = x1_raw.size()
        # on the bottom level are original images
        x1_pyramid = self.feature_pyramid_extractor(x1_raw) + [x1_raw]
        x2_pyramid = self.feature_pyramid_extractor(x2_raw) + [x2_raw]
        # outputs
        output_dict = {}
        flows = []
        occs = []
        # init: zero flow and occlusion at the coarsest resolution
        b_size, _, h_x1, w_x1, = x1_pyramid[0].size()
        init_dtype = x1_pyramid[0].dtype
        init_device = x1_pyramid[0].device
        flow = torch.zeros(b_size, 2, h_x1, w_x1, dtype=init_dtype, device=init_device).float()
        occ = torch.zeros(b_size, 1, h_x1, w_x1, dtype=init_dtype, device=init_device).float()
        for l, (x1, x2) in enumerate(zip(x1_pyramid, x2_pyramid)):
            # warping: upsample previous estimates, then warp x2 by the flow
            if l == 0:
                x2_warp = x2
            else:
                flow = upsample2d_as(flow, x1, mode="bilinear")
                occ = upsample2d_as(occ, x1, mode="bilinear")
                x2_warp = self.warping_layer(x2, flow, height_im, width_im, self._div_flow)
            # correlation (cost volume) between x1 and the warped x2
            out_corr = compute_cost_volume(x1, x2_warp, self.corr_params)
            out_corr_relu = self.leakyRELU(out_corr)
            # concat and estimate flow: convert to level-local units for the
            # shared estimator, apply residual + context updates, convert back
            flow = rescale_flow(flow, self._div_flow, width_im, height_im, to_local=True)
            x1_1by1 = self.conv_1x1[l](x1)
            x_intm, flow_res = self.flow_estimators(torch.cat([out_corr_relu, x1_1by1, flow], dim=1))
            flow = flow + flow_res
            flow_fine = self.context_networks(torch.cat([x_intm, flow], dim=1))
            flow = flow + flow_fine
            flow = rescale_flow(flow, self._div_flow, width_im, height_im, to_local=False)
            flows.append(flow)
            # occlusion: residual estimate plus context refinement
            x_intm_occ, occ_res = self.occ_estimators(torch.cat([out_corr_relu, x1_1by1, occ], dim=1))
            occ = occ + occ_res
            occ_fine = self.occ_context_networks(torch.cat([x_intm_occ, occ], dim=1))
            occ = occ + occ_fine
            occs.append(occ)
            # upsampling or post-processing: stop after the output level
            if l == self.output_level:
                break
        output_dict['flow'] = flows
        output_dict['occ'] = occs
        if self.training:
            return output_dict
        else:
            # Evaluation: full-resolution flow (pixel units) and occlusion map.
            output_dict_eval = {}
            output_dict_eval['flow'] = upsample2d_as(flow, x1_raw, mode="bilinear") * (1.0 / self._div_flow)
            output_dict_eval['occ'] = upsample2d_as(occ, x1_raw, mode="bilinear")
            return output_dict_eval
================================================
FILE: models/pwcnet_irr_occ_bi.py
================================================
from __future__ import absolute_import, division, print_function
import torch
import torch.nn as nn
from .pwc_modules import conv, rescale_flow, upsample2d_as, initialize_msra, compute_cost_volume
from .pwc_modules import WarpingLayer, FeatureExtractor, FlowEstimatorDense, ContextNetwork, OccEstimatorDense, OccContextNetwork
class PWCNet(nn.Module):
    """Bidirectional iterative-residual PWC-Net with joint occlusion
    estimation: shared flow/occlusion estimators are applied at every pyramid
    level and in both directions (im1->im2 and im2->im1)."""

    def __init__(self, args, div_flow=0.05):
        super(PWCNet, self).__init__()
        self.args = args
        self._div_flow = div_flow  # global flow scale; undone at evaluation time
        self.search_range = 4  # cost-volume displacement radius
        self.num_chs = [3, 16, 32, 64, 96, 128, 196]  # pyramid channel widths
        self.output_level = 4  # finest pyramid level at which estimates are produced
        self.num_levels = 7
        self.leakyRELU = nn.LeakyReLU(0.1, inplace=True)
        self.feature_pyramid_extractor = FeatureExtractor(self.num_chs)
        self.warping_layer = WarpingLayer()
        self.dim_corr = (self.search_range * 2 + 1) ** 2  # (2r+1)^2 correlation channels
        self.num_ch_in_flo = self.dim_corr + 32 + 2  # corr + adapted features + flow
        self.num_ch_in_occ = self.dim_corr + 32 + 1  # corr + adapted features + occ
        self.flow_estimators = FlowEstimatorDense(self.num_ch_in_flo)
        self.context_networks = ContextNetwork(self.num_ch_in_flo + 448 + 2)
        self.occ_estimators = OccEstimatorDense(self.num_ch_in_occ)
        self.occ_context_networks = OccContextNetwork(self.num_ch_in_occ + 448 + 1)
        # One 1x1 adapter per level (coarse to fine), mapping widths to 32 channels.
        self.conv_1x1 = nn.ModuleList([conv(196, 32, kernel_size=1, stride=1, dilation=1),
                                       conv(128, 32, kernel_size=1, stride=1, dilation=1),
                                       conv(96, 32, kernel_size=1, stride=1, dilation=1),
                                       conv(64, 32, kernel_size=1, stride=1, dilation=1),
                                       conv(32, 32, kernel_size=1, stride=1, dilation=1)])
        self.corr_params = {"pad_size": self.search_range, "kernel_size": 1, "max_disp": self.search_range, "stride1": 1, "stride2": 1, "corr_multiply": 1}
        initialize_msra(self.modules())

    def forward(self, input_dict):
        """Return per-level [forward, backward] flow and occlusion pairs while
        training, or the final full-resolution forward flow/occlusion at
        evaluation time."""
        x1_raw = input_dict['input1']
        x2_raw = input_dict['input2']
        _, _, height_im, width_im = x1_raw.size()
        # on the bottom level are original images
        x1_pyramid = self.feature_pyramid_extractor(x1_raw) + [x1_raw]
        x2_pyramid = self.feature_pyramid_extractor(x2_raw) + [x2_raw]
        # outputs
        output_dict = {}
        flows = []
        occs = []
        # init: zero flow and occlusion in both directions at the coarsest resolution
        b_size, _, h_x1, w_x1, = x1_pyramid[0].size()
        init_dtype = x1_pyramid[0].dtype
        init_device = x1_pyramid[0].device
        flow_f = torch.zeros(b_size, 2, h_x1, w_x1, dtype=init_dtype, device=init_device).float()
        flow_b = torch.zeros(b_size, 2, h_x1, w_x1, dtype=init_dtype, device=init_device).float()
        occ_f = torch.zeros(b_size, 1, h_x1, w_x1, dtype=init_dtype, device=init_device).float()
        occ_b = torch.zeros(b_size, 1, h_x1, w_x1, dtype=init_dtype, device=init_device).float()
        for l, (x1, x2) in enumerate(zip(x1_pyramid, x2_pyramid)):
            # warping: each image's features are warped by the opposite-direction flow
            if l == 0:
                x2_warp = x2
                x1_warp = x1
            else:
                flow_f = upsample2d_as(flow_f, x1, mode="bilinear")
                flow_b = upsample2d_as(flow_b, x2, mode="bilinear")
                occ_f = upsample2d_as(occ_f, x1, mode="bilinear")
                occ_b = upsample2d_as(occ_b, x2, mode="bilinear")
                x2_warp = self.warping_layer(x2, flow_f, height_im, width_im, self._div_flow)
                x1_warp = self.warping_layer(x1, flow_b, height_im, width_im, self._div_flow)
            # correlation: cost volumes for both directions
            out_corr_f = compute_cost_volume(x1, x2_warp, self.corr_params)
            out_corr_b = compute_cost_volume(x2, x1_warp, self.corr_params)
            out_corr_relu_f = self.leakyRELU(out_corr_f)
            out_corr_relu_b = self.leakyRELU(out_corr_b)
            # concat and estimate flow: convert to level-local units for the
            # shared estimator, apply residual + context updates, convert back
            flow_f = rescale_flow(flow_f, self._div_flow, width_im, height_im, to_local=True)
            flow_b = rescale_flow(flow_b, self._div_flow, width_im, height_im, to_local=True)
            x1_1by1 = self.conv_1x1[l](x1)
            x2_1by1 = self.conv_1x1[l](x2)
            x_intm_f, flow_res_f = self.flow_estimators(torch.cat([out_corr_relu_f, x1_1by1, flow_f], dim=1))
            x_intm_b, flow_res_b = self.flow_estimators(torch.cat([out_corr_relu_b, x2_1by1, flow_b], dim=1))
            flow_f = flow_f + flow_res_f
            flow_b = flow_b + flow_res_b
            flow_fine_f = self.context_networks(torch.cat([x_intm_f, flow_f], dim=1))
            flow_fine_b = self.context_networks(torch.cat([x_intm_b, flow_b], dim=1))
            flow_f = flow_f + flow_fine_f
            flow_b = flow_b + flow_fine_b
            flow_f = rescale_flow(flow_f, self._div_flow, width_im, height_im, to_local=False)
            flow_b = rescale_flow(flow_b, self._div_flow, width_im, height_im, to_local=False)
            flows.append([flow_f, flow_b])
            # occ estimation: residual estimates plus context refinement, both directions
            x_intm_occ_f, occ_res_f = self.occ_estimators(torch.cat([out_corr_relu_f, x1_1by1, occ_f], dim=1))
            x_intm_occ_b, occ_res_b = self.occ_estimators(torch.cat([out_corr_relu_b, x2_1by1, occ_b], dim=1))
            occ_f = occ_f + occ_res_f
            occ_b = occ_b + occ_res_b
            occ_fine_f = self.occ_context_networks(torch.cat([x_intm_occ_f, occ_f], dim=1))
            occ_fine_b = self.occ_context_networks(torch.cat([x_intm_occ_b, occ_b], dim=1))
            occ_f = occ_f + occ_fine_f
            occ_b = occ_b + occ_fine_b
            occs.append([occ_f, occ_b])
            # upsampling or post-processing: stop after the output level
            if l == self.output_level:
                break
        output_dict['flow'] = flows
        output_dict['occ'] = occs
        if self.training:
            return output_dict
        else:
            # Evaluation returns only the forward flow/occlusion, flow in pixel units.
            output_dict_eval = {}
            output_dict_eval['flow'] = upsample2d_as(flow_f, x1_raw, mode="bilinear") * (1.0 / self._div_flow)
            output_dict_eval['occ'] = upsample2d_as(occ_f, x1_raw, mode="bilinear")
            return output_dict_eval
================================================
FILE: models/pwcnet_occ.py
================================================
from __future__ import absolute_import, division, print_function
import torch
import torch.nn as nn
from .pwc_modules import upsample2d_as, initialize_msra, compute_cost_volume
from .pwc_modules import WarpingLayer, FeatureExtractor, FlowEstimatorDense, ContextNetwork, OccEstimatorDense, OccContextNetwork
class PWCNet(nn.Module):
    """PWC-Net with joint occlusion estimation (unidirectional): per-level
    dense flow and occlusion estimators, with context networks refining both
    predictions at the finest output level."""

    def __init__(self, args, div_flow=0.05):
        super(PWCNet, self).__init__()
        self.args = args
        self._div_flow = div_flow  # global flow scale; undone at evaluation time
        self.search_range = 4  # cost-volume displacement radius
        self.num_chs = [3, 16, 32, 64, 96, 128, 196]  # pyramid channel widths
        self.output_level = 4  # finest pyramid level at which estimates are produced
        self.num_levels = 7
        self.leakyRELU = nn.LeakyReLU(0.1, inplace=True)
        self.feature_pyramid_extractor = FeatureExtractor(self.num_chs)
        self.warping_layer = WarpingLayer()
        self.flow_estimators = nn.ModuleList()
        self.occ_estimators = nn.ModuleList()
        # (2r+1)^2 correlation channels for search radius r.
        self.dim_corr = (self.search_range * 2 + 1) ** 2
        for l, ch in enumerate(self.num_chs[::-1]):
            if l > self.output_level:
                break
            if l == 0:
                # Coarsest level: no previous flow/occ/features to concatenate.
                num_ch_in = self.dim_corr
                num_ch_in_occ = self.dim_corr
            else:
                # corr channels + this level's features + 2 flow / 1 occ channels.
                num_ch_in = self.dim_corr + ch + 2
                num_ch_in_occ = self.dim_corr + ch + 1
            layer = FlowEstimatorDense(num_ch_in)
            layer_occ = OccEstimatorDense(num_ch_in_occ)
            self.flow_estimators.append(layer)
            self.occ_estimators.append(layer_occ)
        self.context_networks = ContextNetwork(self.dim_corr + 32 + 2 + 448 + 2)
        self.context_networks_occ = OccContextNetwork(self.dim_corr + 32 + 1 + 448 + 1)
        self.corr_params = {"pad_size": self.search_range, "kernel_size": 1, "max_disp": self.search_range, "stride1": 1, "stride2": 1, "corr_multiply": 1}
        initialize_msra(self.modules())

    def forward(self, input_dict):
        """Return per-level flow and occlusion predictions while training, or
        the final full-resolution flow/occlusion pair at evaluation time."""
        x1_raw = input_dict['input1']
        x2_raw = input_dict['input2']
        _, _, height_im, width_im = x1_raw.size()
        # on the bottom level are original images
        x1_pyramid = self.feature_pyramid_extractor(x1_raw) + [x1_raw]
        x2_pyramid = self.feature_pyramid_extractor(x2_raw) + [x2_raw]
        # outputs
        output_dict = {}
        flows = []
        occs = []
        # init: zero flow and occlusion at the coarsest resolution
        b_size, _, h_x1, w_x1, = x1_pyramid[0].size()
        init_dtype = x1_pyramid[0].dtype
        init_device = x1_pyramid[0].device
        flow = torch.zeros(b_size, 2, h_x1, w_x1, dtype=init_dtype, device=init_device).float()
        occ = torch.zeros(b_size, 1, h_x1, w_x1, dtype=init_dtype, device=init_device).float()
        for l, (x1, x2) in enumerate(zip(x1_pyramid, x2_pyramid)):
            # warping: upsample previous estimates, then warp x2 by the flow
            if l == 0:
                x2_warp = x2
            else:
                flow = upsample2d_as(flow, x1, mode="bilinear")
                occ = upsample2d_as(occ, x1, mode="bilinear")
                x2_warp = self.warping_layer(x2, flow, height_im, width_im, self._div_flow)
            # correlation (cost volume) between x1 and the warped x2
            out_corr = compute_cost_volume(x1, x2_warp, self.corr_params)
            out_corr_relu = self.leakyRELU(out_corr)
            # flow / occlusion estimators for this level
            if l == 0:
                x_intm, flow = self.flow_estimators[l](out_corr_relu)
                x_intm_occ, occ = self.occ_estimators[l](out_corr_relu)
            else:
                x_intm, flow = self.flow_estimators[l](torch.cat([out_corr_relu, x1, flow], dim=1))
                x_intm_occ, occ = self.occ_estimators[l](torch.cat([out_corr_relu, x1, occ], dim=1))
            # upsampling or post-processing: context refinement only at the output level
            if l != self.output_level:
                flows.append(flow)
                occs.append(occ)
            else:
                flow_fine = self.context_networks(torch.cat([x_intm, flow], dim=1))
                flow = flow + flow_fine
                flows.append(flow)
                occ_fine = self.context_networks_occ(torch.cat([x_intm_occ, occ], dim=1))
                occ = occ + occ_fine
                occs.append(occ)
                break
        output_dict['flow'] = flows
        output_dict['occ'] = occs
        if self.training:
            return output_dict
        else:
            # Evaluation: full-resolution flow (pixel units) and occlusion map.
            output_dict_eval = {}
            output_dict_eval['flow'] = upsample2d_as(flow, x1_raw, mode="bilinear") * (1.0 / self._div_flow)
            output_dict_eval['occ'] = upsample2d_as(occ, x1_raw, mode="bilinear")
            return output_dict_eval
================================================
FILE: models/pwcnet_occ_bi.py
================================================
from __future__ import absolute_import, division, print_function
import torch
import torch.nn as nn
from .pwc_modules import upsample2d_as, initialize_msra, compute_cost_volume
from .pwc_modules import WarpingLayer, FeatureExtractor, FlowEstimatorDense, ContextNetwork, OccEstimatorDense, OccContextNetwork
class PWCNet(nn.Module):
    """Bidirectional PWC-Net with joint occlusion estimation: per-level flow
    and occlusion estimators are shared between the forward (im1->im2) and
    backward (im2->im1) directions, with context networks refining both
    predictions at the finest output level."""

    def __init__(self, args, div_flow=0.05):
        super(PWCNet, self).__init__()
        self.args = args
        self._div_flow = div_flow  # global flow scale; undone at evaluation time
        self.search_range = 4  # cost-volume displacement radius
        self.num_chs = [3, 16, 32, 64, 96, 128, 196]  # pyramid channel widths
        self.output_level = 4  # finest pyramid level at which estimates are produced
        self.num_levels = 7
        self.leakyRELU = nn.LeakyReLU(0.1, inplace=True)
        self.feature_pyramid_extractor = FeatureExtractor(self.num_chs)
        self.warping_layer = WarpingLayer()
        self.flow_estimators = nn.ModuleList()
        self.occ_estimators = nn.ModuleList()
        # (2r+1)^2 correlation channels for search radius r.
        self.dim_corr = (self.search_range * 2 + 1) ** 2
        for l, ch in enumerate(self.num_chs[::-1]):
            if l > self.output_level:
                break
            if l == 0:
                # Coarsest level: no previous flow/occ/features to concatenate.
                num_ch_in = self.dim_corr
                num_ch_in_occ = self.dim_corr
            else:
                # corr channels + this level's features + 2 flow / 1 occ channels.
                num_ch_in = self.dim_corr + ch + 2
                num_ch_in_occ = self.dim_corr + ch + 1
            layer = FlowEstimatorDense(num_ch_in)
            layer_occ = OccEstimatorDense(num_ch_in_occ)
            self.flow_estimators.append(layer)
            self.occ_estimators.append(layer_occ)
        self.context_networks = ContextNetwork(self.dim_corr + 32 + 2 + 448 + 2)
        self.context_networks_occ = OccContextNetwork(self.dim_corr + 32 + 1 + 448 + 1)
        self.corr_params = {"pad_size": self.search_range, "kernel_size": 1, "max_disp": self.search_range, "stride1": 1, "stride2": 1, "corr_multiply": 1}
        initialize_msra(self.modules())

    def forward(self, input_dict):
        """Return per-level [forward, backward] flow and occlusion pairs while
        training, or the final full-resolution forward flow/occlusion at
        evaluation time."""
        x1_raw = input_dict['input1']
        x2_raw = input_dict['input2']
        _, _, height_im, width_im = x1_raw.size()
        # on the bottom level are original images
        x1_pyramid = self.feature_pyramid_extractor(x1_raw) + [x1_raw]
        x2_pyramid = self.feature_pyramid_extractor(x2_raw) + [x2_raw]
        # outputs
        output_dict = {}
        flows = []
        occs = []
        # init: zero flow and occlusion in both directions at the coarsest resolution
        b_size, _, h_x1, w_x1, = x1_pyramid[0].size()
        init_dtype = x1_pyramid[0].dtype
        init_device = x1_pyramid[0].device
        flow_f = torch.zeros(b_size, 2, h_x1, w_x1, dtype=init_dtype, device=init_device).float()
        flow_b = torch.zeros(b_size, 2, h_x1, w_x1, dtype=init_dtype, device=init_device).float()
        occ_f = torch.zeros(b_size, 1, h_x1, w_x1, dtype=init_dtype, device=init_device).float()
        occ_b = torch.zeros(b_size, 1, h_x1, w_x1, dtype=init_dtype, device=init_device).float()
        for l, (x1, x2) in enumerate(zip(x1_pyramid, x2_pyramid)):
            # warping: each image's features are warped by the opposite-direction flow
            if l == 0:
                x2_warp = x2
                x1_warp = x1
            else:
                flow_f = upsample2d_as(flow_f, x1, mode="bilinear")
                flow_b = upsample2d_as(flow_b, x2, mode="bilinear")
                occ_f = upsample2d_as(occ_f, x1, mode="bilinear")
                occ_b = upsample2d_as(occ_b, x2, mode="bilinear")
                x2_warp = self.warping_layer(x2, flow_f, height_im, width_im, self._div_flow)
                x1_warp = self.warping_layer(x1, flow_b, height_im, width_im, self._div_flow)
            # correlation: cost volumes for both directions
            out_corr_f = compute_cost_volume(x1, x2_warp, self.corr_params)
            out_corr_b = compute_cost_volume(x2, x1_warp, self.corr_params)
            out_corr_relu_f = self.leakyRELU(out_corr_f)
            out_corr_relu_b = self.leakyRELU(out_corr_b)
            # flow / occlusion estimators (weights shared across directions)
            if l == 0:
                x_intm_f, flow_f = self.flow_estimators[l](out_corr_relu_f)
                x_intm_b, flow_b = self.flow_estimators[l](out_corr_relu_b)
                x_intm_occ_f, occ_f = self.occ_estimators[l](out_corr_relu_f)
                x_intm_occ_b, occ_b = self.occ_estimators[l](out_corr_relu_b)
            else:
                x_intm_f, flow_f = self.flow_estimators[l](torch.cat([out_corr_relu_f, x1, flow_f], dim=1))
                x_intm_b, flow_b = self.flow_estimators[l](torch.cat([out_corr_relu_b, x2, flow_b], dim=1))
                x_intm_occ_f, occ_f = self.occ_estimators[l](torch.cat([out_corr_relu_f, x1, occ_f], dim=1))
                # BUGFIX: the backward occlusion estimator must see the second
                # image's features (x2), matching the backward flow estimator
                # above; it previously concatenated x1 (same channel count, so
                # the error was silent).
                x_intm_occ_b, occ_b = self.occ_estimators[l](torch.cat([out_corr_relu_b, x2, occ_b], dim=1))
            # upsampling or post-processing: context refinement only at the output level
            if l != self.output_level:
                flows.append([flow_f, flow_b])
                occs.append([occ_f, occ_b])
            else:
                flow_fine_f = self.context_networks(torch.cat([x_intm_f, flow_f], dim=1))
                flow_fine_b = self.context_networks(torch.cat([x_intm_b, flow_b], dim=1))
                flow_f = flow_f + flow_fine_f
                flow_b = flow_b + flow_fine_b
                flows.append([flow_f, flow_b])
                occ_fine_f = self.context_networks_occ(torch.cat([x_intm_occ_f, occ_f], dim=1))
                occ_fine_b = self.context_networks_occ(torch.cat([x_intm_occ_b, occ_b], dim=1))
                occ_f = occ_f + occ_fine_f
                occ_b = occ_b + occ_fine_b
                occs.append([occ_f, occ_b])
                break
        output_dict['flow'] = flows
        output_dict['occ'] = occs
        if self.training:
            return output_dict
        else:
            # Evaluation returns only the forward flow/occlusion, flow in pixel units.
            output_dict_eval = {}
            output_dict_eval['flow'] = upsample2d_as(flow_f, x1_raw, mode="bilinear") * (1.0 / self._div_flow)
            output_dict_eval['occ'] = upsample2d_as(occ_f, x1_raw, mode="bilinear")
            return output_dict_eval
================================================
FILE: optim/__init__.py
================================================
import sys

import torch

from tools import module_classes_to_dict

# ------------------------------------------------------------------------------------
# Export every concrete PyTorch optimizer class (Adam, SGD, ...) as an attribute of
# this module, so configuration code can look optimizers up by name here.
# The abstract 'Optimizer' base class is excluded on purpose.
# ------------------------------------------------------------------------------------
_this = sys.modules[__name__]
_optimizer_classes = module_classes_to_dict(torch.optim, exclude_classes="Optimizer")
for name, constructor in _optimizer_classes.items():
    setattr(_this, name, constructor)

# __all__ is conventionally a list of strings; the previous dict key *view* was a
# live view object tied to _optimizer_classes rather than a plain list.
__all__ = list(_optimizer_classes.keys())
================================================
FILE: runtime.py
================================================
## Portions of Code from, copyright 2018 Jochen Gast
from __future__ import absolute_import, division, print_function
import numpy as np
import colorama
import logging
import logger
import tools
from tools import MovingAverage
import collections
import scipy.misc
import torch
import torch.nn as nn
import os
# for evaluation
from utils.flow import flow_to_png, flow_to_png_middlebury
from utils.flow import write_flow, write_flow_png
# --------------------------------------------------------------------------------
# Exponential moving average smoothing factor for speed estimates
# Ranges from 0 (average speed) to 1 (current/instantaneous speed) [default: 0.3].
# --------------------------------------------------------------------------------
TQDM_SMOOTHING = 0


# -------------------------------------------------------------------------------------------
# Magic progressbar for inputs of type 'iterable'
# -------------------------------------------------------------------------------------------
def create_progressbar(iterable,
                       desc="",
                       train=False,
                       unit="it",
                       initial=0,
                       offset=0,
                       invert_iterations=False,
                       logging_on_update=False,
                       logging_on_close=True,
                       postfix=False):
    """Wrap `iterable` in a colorized tqdm progress bar with logging hooks.

    Note: the `postfix` flag is accepted for call-site symmetry but is not
    referenced inside this function.
    """
    # ---------------------------------------------------------------
    # Pick colors
    # ---------------------------------------------------------------
    reset = colorama.Style.RESET_ALL
    bright = colorama.Style.BRIGHT
    cyan = colorama.Fore.CYAN
    dim = colorama.Style.DIM
    green = colorama.Fore.GREEN

    # ---------------------------------------------------------------
    # Specify progressbar layout. Available fields:
    # l_bar, bar, r_bar, n, n_fmt, total, total_fmt, percentage,
    # rate, rate_fmt, rate_noinv, rate_noinv_fmt, rate_inv,
    # rate_inv_fmt, elapsed, remaining, desc, postfix.
    # ---------------------------------------------------------------
    timing_field = " {rate_inv_fmt} " if invert_iterations else " {rate_noinv_fmt} "
    segments = [
        "%s==>%s%s {desc}:%s " % (cyan, reset, bright, reset),  # description
        "{percentage:3.0f}%",                                   # percentage
        "%s|{bar}|%s " % (dim, reset),                          # bar
        " {n_fmt}/{total_fmt} ",                                # i/n counter
        "{elapsed}<{remaining}",                                # eta
        timing_field,                                           # iteration timings
        "%s{postfix}%s" % (green, reset),                       # postfix
    ]
    bar_format = "".join(segments)

    # ---------------------------------------------------------------
    # Hand everything to the logging-aware tqdm wrapper
    # ---------------------------------------------------------------
    return tools.tqdm_with_logging(
        iterable=iterable,                     # what is iterated over
        desc=desc,                             # prefix for the progress bar
        total=len(iterable),                   # number of expected iterations
        leave=True,                            # leave progress bar when done
        miniters=1 if train else None,         # minimum display update interval in iterations
        unit=unit,                             # unit label for each iteration
        initial=initial,                       # the initial counter value
        dynamic_ncols=True,                    # allow window resizes
        smoothing=TQDM_SMOOTHING,              # moving average smoothing factor for speed estimates
        bar_format=bar_format,                 # custom bar string formatting
        position=offset,                       # vertical line offset
        ascii=True,
        logging_on_update=logging_on_update,
        logging_on_close=logging_on_close)
def tensor2float_dict(tensor_dict):
    """Convert a dict of scalar tensors into a dict of plain Python numbers."""
    converted = {}
    for name, scalar_tensor in tensor_dict.items():
        converted[name] = scalar_tensor.item()
    return converted
def format_moving_averages_as_progress_dict(moving_averages_dict=None,
                                            moving_averages_postfix="avg"):
    """Format moving-average meters as an ordered progress-bar postfix dict.

    Args:
        moving_averages_dict: mapping from loss name to an object exposing
            ``mean()`` (e.g. tools.MovingAverage). Defaults to empty.
        moving_averages_postfix: suffix appended to every key.

    Returns:
        collections.OrderedDict mapping ``key + postfix`` to "%1.4f"-formatted
        strings; keys are sorted so the display is stable across steps.
    """
    # Use a None sentinel instead of a mutable {} default argument, which would
    # be a single dict shared across all calls.
    if moving_averages_dict is None:
        moving_averages_dict = {}
    progress_dict = collections.OrderedDict([
        (key + moving_averages_postfix, "%1.4f" % moving_averages_dict[key].mean())
        for key in sorted(moving_averages_dict.keys())
    ])
    return progress_dict
def format_learning_rate(lr):
    """Render a learning rate (scalar, or per-parameter-group sequence) as a string."""
    if np.isscalar(lr):
        return "{}".format(lr)
    # A single parameter group is shown as its lone value; several groups are
    # shown as the full list.
    only_one_group = len(lr) == 1
    return "{}".format(str(lr[0]) if only_one_group else lr)
class TrainingEpoch:
    """Runs a single training epoch.

    For every batch yielded by `loader`: transfers tensors to CUDA (if
    enabled), applies the optional augmentation, runs the combined
    model-and-loss forward pass, back-propagates the loss selected by
    ``args.training_key`` and performs one optimizer update. Per-loss
    exponential moving averages are shown on the progress bar.
    """

    def __init__(self,
                 args,
                 model_and_loss,
                 loader,
                 optimizer,
                 augmentation=None,
                 add_progress_stats=None,
                 desc="Training Epoch"):
        self._args = args
        self._desc = desc
        self._loader = loader
        self._model_and_loss = model_and_loss
        self._optimizer = optimizer
        self._augmentation = augmentation
        # None sentinel instead of a mutable {} default: a default dict would be
        # one object aliased across every TrainingEpoch instance.
        self._add_progress_stats = {} if add_progress_stats is None else add_progress_stats

    def _step(self, example_dict):
        """Forward/backward/update on one batch.

        Returns:
            (loss_dict, output_dict, batch_size) where loss_dict values are
            still tensors.
        """
        # -------------------------------------------------------------
        # Get input and target tensor keys
        # -------------------------------------------------------------
        input_keys = list(filter(lambda x: "input" in x, example_dict.keys()))
        target_keys = list(filter(lambda x: "target" in x, example_dict.keys()))
        tensor_keys = input_keys + target_keys

        # -------------------------------------------------------------
        # Possibly transfer to Cuda
        # -------------------------------------------------------------
        if self._args.cuda:
            for key, value in example_dict.items():
                if key in tensor_keys:
                    example_dict[key] = value.cuda(non_blocking=False)

        # -------------------------------------------------------------
        # Optionally perform augmentations (no gradients tracked here)
        # -------------------------------------------------------------
        if self._augmentation is not None:
            with torch.no_grad():
                example_dict = self._augmentation(example_dict)

        # -------------------------------------------------------------
        # Inputs require gradients; targets do not
        # -------------------------------------------------------------
        for key, tensor in example_dict.items():
            if key in input_keys:
                example_dict[key] = tensor.requires_grad_(True)
            elif key in target_keys:
                example_dict[key] = tensor.requires_grad_(False)

        # -------------------------------------------------------------
        # Extract batch size from first input
        # -------------------------------------------------------------
        batch_size = example_dict["input1"].size()[0]

        # -------------------------------------------------------------
        # Reset gradients accumulated by the previous step
        # -------------------------------------------------------------
        self._optimizer.zero_grad()

        # -------------------------------------------------------------
        # Run forward pass to get losses and outputs.
        # -------------------------------------------------------------
        loss_dict, output_dict = self._model_and_loss(example_dict)

        # -------------------------------------------------------------
        # Guard against divergence: abort the run on a NaN training loss
        # -------------------------------------------------------------
        training_loss = loss_dict[self._args.training_key]
        assert (not np.isnan(training_loss.item())), "training_loss is NaN"

        # -------------------------------------------------------------
        # Back propagation and parameter update
        # -------------------------------------------------------------
        training_loss.backward()
        self._optimizer.step()

        # -------------------------------------------------------------
        # Return loss and output dictionary
        # -------------------------------------------------------------
        return loss_dict, output_dict, batch_size

    def run(self, offset=0):
        """Iterate once over the loader; return the EMA-smoothed loss dict."""
        # ---------------------------------------
        # Tell model that we want to train
        # ---------------------------------------
        self._model_and_loss.train()

        # ---------------------------------------
        # Moving averages are created lazily on the first step, once the
        # loss keys are known.
        # ---------------------------------------
        moving_averages_dict = None

        # ---------------------------------------
        # Progress bar arguments
        # ---------------------------------------
        progressbar_args = {
            "iterable": self._loader,
            "desc": self._desc,
            "train": True,
            "offset": offset,
            "logging_on_update": False,
            "logging_on_close": True,
            "postfix": True
        }

        # ---------------------------------------
        # Perform training steps
        # ---------------------------------------
        with create_progressbar(**progressbar_args) as progress:
            for example_dict in progress:
                # perform step
                loss_dict_per_step, output_dict, batch_size = self._step(example_dict)
                # convert losses to plain floats
                loss_dict_per_step = tensor2float_dict(loss_dict_per_step)

                # --------------------------------------------------------
                # Possibly initialize moving averages
                # --------------------------------------------------------
                if moving_averages_dict is None:
                    moving_averages_dict = {
                        key: MovingAverage() for key in loss_dict_per_step.keys()
                    }

                # --------------------------------------------------------
                # Add this step's losses, weighted by batch size
                # --------------------------------------------------------
                for key, loss in loss_dict_per_step.items():
                    moving_averages_dict[key].add_average(loss, addcount=batch_size)

                # view statistics in progress bar
                progress_stats = format_moving_averages_as_progress_dict(
                    moving_averages_dict=moving_averages_dict,
                    moving_averages_postfix="_ema")
                progress.set_postfix(progress_stats)

        # -------------------------------------------------------------
        # Return smoothed losses for the whole epoch
        # -------------------------------------------------------------
        ema_loss_dict = {key: ma.mean() for key, ma in moving_averages_dict.items()}
        return ema_loss_dict
class EvaluationEpoch:
    """Runs one evaluation epoch (no gradient tracking) and optionally writes
    predicted flow / occlusion results to disk under ``args.save``."""

    def __init__(self,
                 args,
                 model_and_loss,
                 loader,
                 augmentation=None,
                 add_progress_stats={},
                 desc="Evaluation Epoch"):
        # NOTE(review): mutable default for add_progress_stats — never mutated
        # in the visible code, but a None sentinel would be safer.
        self._args = args
        self._desc = desc
        self._loader = loader
        self._model_and_loss = model_and_loss
        self._add_progress_stats = add_progress_stats
        self._augmentation = augmentation
        # Any of the save_result_* flags enables writing outputs to disk.
        self._save_output = False
        if self._args.save_result_img or self._args.save_result_flo or self._args.save_result_png:
            self._save_output = True

    def save_outputs(self, example_dict, output_dict):
        """Write flow (.flo/.png) and visualization images for one batch.

        Expects output_dict["flow"] (and "flow_b"/"occ"/"occ_b" depending on
        the save_result_* flags) to be 4D tensors in NCHW layout — the
        transpose/swapaxes calls below rely on that layout.
        """
        # save occ
        save_root_img = self._args.save + '/img/'
        save_root_flo = self._args.save + '/flo/'

        if self._args.save_result_bidirection:
            flow_f = output_dict["flow"].data.cpu().numpy()
            flow_b = output_dict["flow_b"].data.cpu().numpy()
            b_size = output_dict["flow"].data.size(0)
        else:
            flow_f = output_dict["flow"].data.cpu().numpy()
            b_size = output_dict["flow"].data.size(0)

        if self._args.save_result_occ:
            # Occlusion logits -> sigmoid probability, rounded to {0,1} and
            # scaled to {0,255}; expanded to 3 channels for image saving.
            if self._args.save_result_bidirection:
                output_occ = np.round(
                    nn.Sigmoid()(output_dict["occ"]).expand(-1, 3, -1, -1).data.cpu().numpy().transpose(
                        [0, 2, 3, 1])) * 255
                output_occ_b = np.round(
                    nn.Sigmoid()(output_dict["occ_b"]).expand(-1, 3, -1, -1).data.cpu().numpy().transpose(
                        [0, 2, 3, 1])) * 255
            else:
                output_occ = np.round(
                    nn.Sigmoid()(output_dict["occ"]).expand(-1, 3, -1, -1).data.cpu().numpy().transpose(
                        [0, 2, 3, 1])) * 255

        # file names: one image path and one flow path per batch element,
        # mirroring the dataset's directory layout when "basedir" is given.
        file_names_img = []
        file_names_flo = []
        for ii in range(0, b_size):
            if "basedir" in example_dict.keys():
                file_name_img = save_root_img + example_dict["basedir"][ii] + '/' + str(example_dict["basename"][ii])
                file_name_flo = save_root_flo + example_dict["basedir"][ii] + '/' + str(example_dict["basename"][ii])
                file_names_img.append(file_name_img)
                file_names_flo.append(file_name_flo)
            else:
                file_name_img = save_root_img + '/' + str(example_dict["basename"][ii])
                file_name_flo = save_root_flo + '/' + str(example_dict["basename"][ii])
                file_names_img.append(file_name_img)
                file_names_flo.append(file_name_flo)

            # Create output directories on demand.
            directory_img = os.path.dirname(file_name_img)
            if not os.path.exists(directory_img):
                os.makedirs(directory_img)
            directory_flo = os.path.dirname(file_name_flo)
            if not os.path.exists(directory_flo):
                os.makedirs(directory_flo)

        if self._args.save_result_img:
            # NOTE(review): scipy.misc.imsave was deprecated in SciPy 1.0 and
            # removed in 1.2 — this path requires an old SciPy (or a port to
            # imageio); confirm the pinned environment.
            for ii in range(0, b_size):
                if self._args.save_result_occ:
                    file_name_occ = file_names_img[ii] + '_occ.png'
                    scipy.misc.imsave(file_name_occ, output_occ[ii])

                    if self._args.save_result_bidirection:
                        scipy.misc.imsave(file_names_img[ii] + '_occ_b.png', output_occ_b[ii])

                # flow vis
                flow_f_rgb = flow_to_png_middlebury(flow_f[ii, ...])
                file_name_flo_vis = file_names_img[ii] + '_flow.png'
                scipy.misc.imsave(file_name_flo_vis, flow_f_rgb)

                if self._args.save_result_bidirection:
                    flow_b_rgb = flow_to_png_middlebury(flow_b[ii, ...])
                    file_name_flo_vis = file_names_img[ii] + '_flow_b.png'
                    scipy.misc.imsave(file_name_flo_vis, flow_b_rgb)

        if self._args.save_result_flo or self._args.save_result_png:
            for ii in range(0, b_size):
                # swapaxes: CHW -> HWC, the layout expected by the flow writers.
                if self._args.save_result_flo:
                    file_name = file_names_flo[ii] + '.flo'
                    write_flow(file_name, flow_f[ii, ...].swapaxes(0, 1).swapaxes(1, 2))

                if self._args.save_result_png:
                    file_name = file_names_flo[ii] + '.png'
                    write_flow_png(file_name, flow_f[ii, ...].swapaxes(0, 1).swapaxes(1, 2))

    def _step(self, example_dict):
        """Forward pass on one batch; returns (loss_dict, output_dict, batch_size)."""
        # -------------------------------------------------------------
        # Get input and target tensor keys
        # -------------------------------------------------------------
        input_keys = list(filter(lambda x: "input" in x, example_dict.keys()))
        target_keys = list(filter(lambda x: "target" in x, example_dict.keys()))
        tensor_keys = input_keys + target_keys

        # -------------------------------------------------------------
        # Possibly transfer to Cuda
        # -------------------------------------------------------------
        if self._args.cuda:
            for key, value in example_dict.items():
                if key in tensor_keys:
                    example_dict[key] = value.cuda(non_blocking=False)

        # -------------------------------------------------------------
        # Optionally perform augmentations
        # -------------------------------------------------------------
        if self._augmentation is not None:
            example_dict = self._augmentation(example_dict)

        # -------------------------------------------------------------
        # Extract batch size from first input
        # -------------------------------------------------------------
        batch_size = example_dict["input1"].size()[0]

        # -------------------------------------------------------------
        # Run forward pass to get losses and outputs.
        # -------------------------------------------------------------
        loss_dict, output_dict = self._model_and_loss(example_dict)

        # -------------------------------------------------------------
        # Return loss and output dictionary
        # -------------------------------------------------------------
        return loss_dict, output_dict, batch_size

    def run(self, offset=0):
        """Evaluate every batch in the loader; return batch-weighted average losses."""
        with torch.no_grad():
            # ---------------------------------------
            # Tell model that we want to evaluate
            # ---------------------------------------
            self._model_and_loss.eval()

            # ---------------------------------------
            # Keep track of moving averages
            # ---------------------------------------
            moving_averages_dict = None

            # ---------------------------------------
            # Progress bar arguments
            # ---------------------------------------
            progressbar_args = {
                "iterable": self._loader,
                "desc": self._desc,
                "train": False,
                "offset": offset,
                "logging_on_update": False,
                "logging_on_close": True,
                "postfix": True
            }

            # ---------------------------------------
            # Perform evaluation steps
            # ---------------------------------------
            with create_progressbar(**progressbar_args) as progress:
                for example_dict in progress:
                    # ---------------------------------------
                    # Perform forward evaluation step
                    # ---------------------------------------
                    loss_dict_per_step, output_dict, batch_size = self._step(example_dict)

                    # --------------------------------------------------------
                    # Save results
                    # --------------------------------------------------------
                    if self._save_output:
                        self.save_outputs(example_dict, output_dict)

                    # ---------------------------------------
                    # Convert loss dictionary to float
                    # ---------------------------------------
                    loss_dict_per_step = tensor2float_dict(loss_dict_per_step)

                    # --------------------------------------------------------
                    # Possibly initialize moving averages
                    # --------------------------------------------------------
                    if moving_averages_dict is None:
                        moving_averages_dict = {
                            key: MovingAverage() for key in loss_dict_per_step.keys()
                        }

                    # --------------------------------------------------------
                    # Add moving averages
                    # --------------------------------------------------------
                    for key, loss in loss_dict_per_step.items():
                        moving_averages_dict[key].add_average(loss, addcount=batch_size)

                    # view statistics in progress bar
                    progress_stats = format_moving_averages_as_progress_dict(
                        moving_averages_dict=moving_averages_dict,
                        moving_averages_postfix="_avg")
                    progress.set_postfix(progress_stats)

            # -------------------------------------------------------------
            # Record average losses
            # -------------------------------------------------------------
            avg_loss_dict = { key: ma.mean() for key, ma in moving_averages_dict.items() }

            # -------------------------------------------------------------
            # Return average losses and output dictionary
            # -------------------------------------------------------------
            return avg_loss_dict
def exec_runtime(args,
                 checkpoint_saver,
                 model_and_loss,
                 optimizer,
                 lr_scheduler,
                 train_loader,
                 validation_loader,
                 inference_loader,
                 training_augmentation,
                 validation_augmentation):
    """Top-level driver: runs training and validation epochs from
    args.start_epoch to args.total_epochs, tracks the best validation loss
    and stores checkpoints after every epoch.

    Note: `inference_loader` is accepted but not referenced in this function.
    """
    # ----------------------------------------------------------------------------------------------
    # Validation schedulers are a bit special:
    # They want to be called with a validation loss..
    # ----------------------------------------------------------------------------------------------
    validation_scheduler = (lr_scheduler is not None and args.lr_scheduler == "ReduceLROnPlateau")
    # NOTE(review): 'validation_scheduler' is computed but never used below; the
    # scheduler is stepped unconditionally without a metric — confirm whether
    # ReduceLROnPlateau should be stepped with the validation loss instead.

    # --------------------------------------------------------
    # Log some runtime info
    # --------------------------------------------------------
    with logger.LoggingBlock("Runtime", emph=True):
        logging.info("start_epoch: %i" % args.start_epoch)
        logging.info("total_epochs: %i" % args.total_epochs)

    # ---------------------------------------
    # Total progress bar arguments
    # ---------------------------------------
    progressbar_args = {
        "desc": "Progress",
        "initial": args.start_epoch - 1,
        "invert_iterations": True,
        "iterable": range(1, args.total_epochs + 1),
        "logging_on_close": True,
        "logging_on_update": True,
        "postfix": False,
        "unit": "ep"
    }

    # --------------------------------------------------------
    # Total progress bar
    # --------------------------------------------------------
    print(''), logging.logbook('')
    total_progress = create_progressbar(**progressbar_args)
    print("\n")

    # --------------------------------------------------------
    # Remember validation loss (sign depends on whether the
    # validation key is minimized or maximized)
    # --------------------------------------------------------
    best_validation_loss = float("inf") if args.validation_key_minimize else -float("inf")
    store_as_best = False

    for epoch in range(args.start_epoch, args.total_epochs + 1):
        with logger.LoggingBlock("Epoch %i/%i" % (epoch, args.total_epochs), emph=True):
            # Always report learning rate
            if lr_scheduler is not None:
                logging.info("lr: %s" % format_learning_rate(lr_scheduler.get_last_lr()))

            # -------------------------------------------
            # Create and run a training epoch
            # -------------------------------------------
            if train_loader is not None:
                avg_loss_dict = TrainingEpoch(
                    args,
                    desc=" Train",
                    model_and_loss=model_and_loss,
                    optimizer=optimizer,
                    loader=train_loader,
                    augmentation=training_augmentation).run()

            # -------------------------------------------
            # Create and run a validation epoch
            # -------------------------------------------
            if validation_loader is not None:
                # ---------------------------------------------------
                # Construct holistic recorder for epoch
                # ---------------------------------------------------
                avg_loss_dict = EvaluationEpoch(
                    args,
                    desc="Validate",
                    model_and_loss=model_and_loss,
                    loader=validation_loader,
                    augmentation=validation_augmentation).run()

                # ----------------------------------------------------------------
                # Evaluate whether this is the best validation_loss
                # ----------------------------------------------------------------
                validation_loss = avg_loss_dict[args.validation_key]
                if args.validation_key_minimize:
                    store_as_best = validation_loss < best_validation_loss
                else:
                    store_as_best = validation_loss > best_validation_loss
                if store_as_best:
                    best_validation_loss = validation_loss

            # Update standard learning scheduler
            if lr_scheduler is not None:
                lr_scheduler.step()

            # ----------------------------------------------------------------
            # Also show best loss on total_progress
            # ----------------------------------------------------------------
            total_progress_stats = {
                "best_" + args.validation_key + "_avg": "%1.4f" % best_validation_loss
            }
            total_progress.set_postfix(total_progress_stats)

            # ----------------------------------------------------------------
            # Bump total progress
            # ----------------------------------------------------------------
            total_progress.update()
            print('')

            # ----------------------------------------------------------------
            # Store checkpoint (latest always; best additionally flagged)
            # ----------------------------------------------------------------
            if checkpoint_saver is not None:
                checkpoint_saver.save_latest(
                    directory=args.save,
                    model_and_loss=model_and_loss,
                    stats_dict=dict(avg_loss_dict, epoch=epoch),
                    store_as_best=store_as_best)

        # ----------------------------------------------------------------
        # Vertical space between epochs
        # ----------------------------------------------------------------
        print(''), logging.logbook('')

    # ----------------------------------------------------------------
    # Finish
    # ----------------------------------------------------------------
    total_progress.close()
    logging.info("Finished.")
================================================
FILE: saved_check_point/pwcnet/IRR-PWC_flyingchairsOcc/checkpoint_best.ckpt
================================================
[File too large to display: 24.3 MB]
================================================
FILE: saved_check_point/pwcnet/IRR-PWC_flyingchairsOcc/checkpoint_latest.ckpt
================================================
[File too large to display: 24.3 MB]
================================================
FILE: saved_check_point/pwcnet/IRR-PWC_kitti/checkpoint_best.ckpt
================================================
[File too large to display: 24.3 MB]
================================================
FILE: saved_check_point/pwcnet/IRR-PWC_kitti/checkpoint_latest.ckpt
================================================
[File too large to display: 24.3 MB]
================================================
FILE: saved_check_point/pwcnet/IRR-PWC_sintel/checkpoint_best.ckpt
================================================
[File too large to display: 24.3 MB]
================================================
FILE: saved_check_point/pwcnet/IRR-PWC_sintel/checkpoint_latest.ckpt
================================================
[File too large to display: 24.3 MB]
================================================
FILE: saved_check_point/pwcnet/IRR-PWC_things3d/checkpoint_best.ckpt
================================================
[File too large to display: 24.3 MB]
================================================
FILE: saved_check_point/pwcnet/IRR-PWC_things3d/checkpoint_latest.ckpt
================================================
[File too large to display: 24.3 MB]
================================================
FILE: saved_check_point/pwcnet/PWCNet/checkpoint_best.ckpt
================================================
[File too large to display: 33.0 MB]
================================================
FILE: saved_check_point/pwcnet/PWCNet-irr/checkpoint_best.ckpt
================================================
[File too large to display: 12.8 MB]
================================================
FILE: scripts/IRR-FlowNet_flyingChairsOcc.sh
================================================
#!/bin/bash

# experiments and datasets meta
EXPERIMENTS_HOME="experiments"

# datasets
# NOTE: replace the placeholder with your dataset root. It is quoted so the
# script still parses before substitution — an unquoted "(YOUR PATH)" is a
# bash syntax error ('(' starts an array assignment).
FLYINGCHAIRS_OCC_HOME="(YOUR PATH)/flow_occ_v5/data"

# model and checkpoint
MODEL=IRR_FlowNet
EVAL_LOSS=MultiScaleEPE_FlowNet_IRR_Bi_Occ_upsample
CHECKPOINT=None
SIZE_OF_BATCH=4

# save path
TIME=$(date +"%Y%m%d-%H%M%S")
SAVE_PATH="$EXPERIMENTS_HOME/$MODEL-$TIME"

# training configuration (path-bearing expansions quoted to survive spaces)
python ../main.py \
--batch_size=$SIZE_OF_BATCH \
--batch_size_val=$SIZE_OF_BATCH \
--checkpoint="$CHECKPOINT" \
--lr_scheduler=MultiStepLR \
--lr_scheduler_gamma=0.5 \
--lr_scheduler_milestones="[54, 72, 90]" \
--model=$MODEL \
--num_workers=4 \
--num_iters=2 \
--optimizer=Adam \
--optimizer_lr=1e-4 \
--optimizer_weight_decay=4e-4 \
--save="$SAVE_PATH" \
--total_epochs=108 \
--training_augmentation=RandomAffineFlowOcc \
--training_dataset=FlyingChairsOccTrain \
--training_dataset_photometric_augmentations=True \
--training_dataset_root="$FLYINGCHAIRS_OCC_HOME" \
--training_key=total_loss \
--training_loss=$EVAL_LOSS \
--validation_dataset=FlyingChairsOccValid \
--validation_dataset_photometric_augmentations=False \
--validation_dataset_root="$FLYINGCHAIRS_OCC_HOME" \
--validation_key=epe \
--validation_loss=$EVAL_LOSS
================================================
FILE: scripts/IRR-PWC_flyingChairsOcc.sh
================================================
#!/bin/bash

# experiments and datasets meta
EXPERIMENTS_HOME="experiments"

# datasets
# NOTE: replace the placeholder with your dataset root. It is quoted so the
# script still parses before substitution — an unquoted "(YOUR PATH)" is a
# bash syntax error ('(' starts an array assignment).
FLYINGCHAIRS_OCC_HOME="(YOUR PATH)/flow_occ_v5/data"

# model and checkpoint
MODEL=IRR_PWC
EVAL_LOSS=MultiScaleEPE_PWC_Bi_Occ_upsample
CHECKPOINT=None
SIZE_OF_BATCH=4

# save path
TIME=$(date +"%Y%m%d-%H%M%S")
SAVE_PATH="$EXPERIMENTS_HOME/$MODEL-$TIME"

# training configuration (path-bearing expansions quoted to survive spaces)
python ../main.py \
--batch_size=$SIZE_OF_BATCH \
--batch_size_val=$SIZE_OF_BATCH \
--checkpoint="$CHECKPOINT" \
--lr_scheduler=MultiStepLR \
--lr_scheduler_gamma=0.5 \
--lr_scheduler_milestones="[54, 72, 90]" \
--model=$MODEL \
--num_workers=4 \
--optimizer=Adam \
--optimizer_lr=1e-4 \
--optimizer_weight_decay=4e-4 \
--save="$SAVE_PATH" \
--total_epochs=108 \
--training_augmentation=RandomAffineFlowOcc \
--training_dataset=FlyingChairsOccTrain \
--training_dataset_photometric_augmentations=True \
--training_dataset_root="$FLYINGCHAIRS_OCC_HOME" \
--training_key=total_loss \
--training_loss=$EVAL_LOSS \
--validation_dataset=FlyingChairsOccValid \
--validation_dataset_photometric_augmentations=False \
--validation_dataset_root="$FLYINGCHAIRS_OCC_HOME" \
--validation_key=epe \
--validation_loss=$EVAL_LOSS
================================================
FILE: scripts/IRR-PWC_kitti_train.sh
================================================
#!/bin/bash

# experiments and datasets meta
EXPERIMENTS_HOME="experiments"

# datasets
# NOTE: replace the placeholder with your KITTI root. It is quoted so the
# script still parses before substitution — an unquoted "(YOUR PATH)" is a
# bash syntax error ('(' starts an array assignment).
KITTI_HOME="(YOUR PATH)/KITTI_flow/"

# model and checkpoint
MODEL=IRR_PWC
EVAL_LOSS=MultiScaleEPE_PWC_Bi_Occ_upsample_KITTI
CHECKPOINT="saved_check_point/IRR-PWC_things3d/checkpoint_latest.ckpt"
SIZE_OF_BATCH=4

# save path
TIME=$(date +"%Y%m%d-%H%M%S")
SAVE_PATH="$EXPERIMENTS_HOME/$MODEL-$TIME"

# training configuration (path-bearing expansions quoted to survive spaces)
python ../main.py \
--batch_size=$SIZE_OF_BATCH \
--batch_size_val=1 \
--checkpoint="$CHECKPOINT" \
--lr_scheduler=MultiStepLR \
--lr_scheduler_gamma=0.5 \
--lr_scheduler_milestones="[730, 984, 1238, 1365, 1397, 1429, 1556, 1683, 1810, 1937]" \
--model=$MODEL \
--num_workers=4 \
--optimizer=Adam \
--optimizer_lr=3e-05 \
--optimizer_weight_decay=4e-4 \
--save="$SAVE_PATH" \
--start_epoch=160 \
--total_epochs=2064 \
--training_augmentation=RandomAffineFlowOccKITTI \
--training_augmentation_crop="[320,896]" \
--training_dataset=KittiCombTrain \
--training_dataset_photometric_augmentations=True \
--training_dataset_root="$KITTI_HOME" \
--training_dataset_preprocessing_crop=True \
--training_key=total_loss \
--training_loss=$EVAL_LOSS \
--validation_dataset=KittiCombVal \
--validation_dataset_photometric_augmentations=False \
--validation_dataset_root="$KITTI_HOME" \
--validation_dataset_preprocessing_crop=False \
--validation_key=epe \
--validation_loss=$EVAL_LOSS
================================================
FILE: scripts/IRR-PWC_kitti_train_full.sh
================================================
#!/bin/bash

# experiments and datasets meta
EXPERIMENTS_HOME="experiments"

# datasets
# NOTE: replace the placeholder with your KITTI root. It is quoted so the
# script still parses before substitution — an unquoted "(YOUR PATH)" is a
# bash syntax error ('(' starts an array assignment).
KITTI_HOME="(YOUR PATH)/KITTI_flow/"

# model and checkpoint
MODEL=IRR_PWC
EVAL_LOSS=MultiScaleEPE_PWC_Bi_Occ_upsample_KITTI
CHECKPOINT="saved_check_point/IRR-PWC_things3d/checkpoint_latest.ckpt"
SIZE_OF_BATCH=4

# save path
TIME=$(date +"%Y%m%d-%H%M%S")
SAVE_PATH="$EXPERIMENTS_HOME/$MODEL-$TIME"

# training configuration (path-bearing expansions quoted to survive spaces)
python ../main.py \
--batch_size=$SIZE_OF_BATCH \
--batch_size_val=1 \
--checkpoint="$CHECKPOINT" \
--lr_scheduler=MultiStepLR \
--lr_scheduler_gamma=0.5 \
--lr_scheduler_milestones="[616, 819, 1022, 1123, 1149, 1174, 1276, 1377, 1479, 1580]" \
--model=$MODEL \
--num_workers=4 \
--optimizer=Adam \
--optimizer_lr=3e-05 \
--optimizer_weight_decay=4e-4 \
--save="$SAVE_PATH" \
--start_epoch=160 \
--total_epochs=710 \
--training_augmentation=RandomAffineFlowOccKITTI \
--training_augmentation_crop="[320,896]" \
--training_dataset=KittiCombFull \
--training_dataset_photometric_augmentations=True \
--training_dataset_root="$KITTI_HOME" \
--training_dataset_preprocessing_crop=True \
--training_key=total_loss \
--training_loss=$EVAL_LOSS \
--validation_dataset=KittiCombVal \
--validation_dataset_photometric_augmentations=False \
--validation_dataset_root="$KITTI_HOME" \
--validation_dataset_preprocessing_crop=False \
--validation_key=epe \
--validation_loss=$EVAL_LOSS
================================================
FILE: scripts/IRR-PWC_sintel_train.sh
================================================
#!/bin/bash

# experiments and datasets meta
EXPERIMENTS_HOME="experiments"

# datasets
# NOTE: replace the placeholder with your Sintel root. It is quoted so the
# script still parses before substitution — an unquoted "(YOUR PATH)" is a
# bash syntax error ('(' starts an array assignment).
SINTEL_HOME="(YOUR PATH)/MPI-Sintel-complete/"

# model and checkpoint
MODEL=IRR_PWC
EVAL_LOSS=MultiScaleEPE_PWC_Bi_Occ_upsample_Sintel
CHECKPOINT="saved_check_point/IRR-PWC_things3d/checkpoint_latest.ckpt"
SIZE_OF_BATCH=4

# save path
TIME=$(date +"%Y%m%d-%H%M%S")
SAVE_PATH="$EXPERIMENTS_HOME/$MODEL-$TIME"

# training configuration: stage 1, combined (clean+final) training split
python ../main.py \
--batch_size=$SIZE_OF_BATCH \
--batch_size_val=$SIZE_OF_BATCH \
--checkpoint="$CHECKPOINT" \
--lr_scheduler=MultiStepLR \
--lr_scheduler_gamma=0.5 \
--lr_scheduler_milestones="[258, 302, 346, 368, 374, 379, 401, 423, 445, 467]" \
--model=$MODEL \
--num_workers=4 \
--optimizer=Adam \
--optimizer_lr=1.5e-05 \
--optimizer_weight_decay=4e-4 \
--save="$SAVE_PATH" \
--start_epoch=160 \
--total_epochs=489 \
--training_augmentation=RandomAffineFlowOccSintel \
--training_augmentation_crop="[384,768]" \
--training_dataset=SintelTrainingCombTrain \
--training_dataset_photometric_augmentations=True \
--training_dataset_root="$SINTEL_HOME" \
--training_key=total_loss \
--training_loss=$EVAL_LOSS \
--validation_dataset=SintelTrainingCombValid \
--validation_dataset_photometric_augmentations=False \
--validation_dataset_root="$SINTEL_HOME" \
--validation_key=epe \
--validation_loss=$EVAL_LOSS

# training configuration: stage 2, fine-tuning on the final pass
python ../main.py \
--batch_size=$SIZE_OF_BATCH \
--batch_size_val=$SIZE_OF_BATCH \
--checkpoint="$CHECKPOINT" \
--lr_scheduler=MultiStepLR \
--lr_scheduler_gamma=0.5 \
--lr_scheduler_milestones="[687, 775, 863, 908, 919, 930, 974, 1018, 1062, 1106]" \
--model=$MODEL \
--num_workers=4 \
--optimizer=Adam \
--optimizer_lr=1e-05 \
--optimizer_weight_decay=4e-4 \
--save="$SAVE_PATH" \
--start_epoch=490 \
--total_epochs=1150 \
--training_augmentation=RandomAffineFlowOccSintel \
--training_augmentation_crop="[384,768]" \
--training_dataset=SintelTrainingFinalTrain \
--training_dataset_photometric_augmentations=True \
--training_dataset_root="$SINTEL_HOME" \
--training_key=total_loss \
--training_loss=$EVAL_LOSS \
--validation_dataset=SintelTrainingFinalValid \
--validation_dataset_photometric_augmentations=False \
--validation_dataset_root="$SINTEL_HOME" \
--validation_key=epe \
--validation_loss=$EVAL_LOSS
================================================
FILE: scripts/IRR-PWC_sintel_train_full.sh
================================================
#!/bin/bash

# experiments and datasets meta
EXPERIMENTS_HOME="experiments"

# datasets
# NOTE: replace the placeholder with your Sintel root. It is quoted so the
# script still parses before substitution — an unquoted "(YOUR PATH)" is a
# bash syntax error ('(' starts an array assignment).
SINTEL_HOME="(YOUR PATH)/MPI-Sintel-complete/"

# model and checkpoint
MODEL=IRR_PWC
EVAL_LOSS=MultiScaleEPE_PWC_Bi_Occ_upsample_Sintel
CHECKPOINT="saved_check_point/IRR-PWC_things3d/checkpoint_latest.ckpt"
SIZE_OF_BATCH=4

# save path
TIME=$(date +"%Y%m%d-%H%M%S")
SAVE_PATH="$EXPERIMENTS_HOME/$MODEL-$TIME"

# training configuration: stage 1, full combined (clean+final) split
python ../main.py \
--batch_size=$SIZE_OF_BATCH \
--batch_size_val=$SIZE_OF_BATCH \
--checkpoint="$CHECKPOINT" \
--lr_scheduler=MultiStepLR \
--lr_scheduler_gamma=0.5 \
--lr_scheduler_milestones="[245, 284, 322, 342, 346, 351, 370, 390, 409, 428]" \
--model=$MODEL \
--num_workers=4 \
--optimizer=Adam \
--optimizer_lr=1.5e-05 \
--optimizer_weight_decay=4e-4 \
--save="$SAVE_PATH" \
--start_epoch=160 \
--total_epochs=447 \
--training_augmentation=RandomAffineFlowOccSintel \
--training_augmentation_crop="[384,768]" \
--training_dataset=SintelTrainingCombFull \
--training_dataset_photometric_augmentations=True \
--training_dataset_root="$SINTEL_HOME" \
--training_key=total_loss \
--training_loss=$EVAL_LOSS \
--validation_dataset=SintelTrainingCombValid \
--validation_dataset_photometric_augmentations=False \
--validation_dataset_root="$SINTEL_HOME" \
--validation_key=epe \
--validation_loss=$EVAL_LOSS

# training configuration: stage 2, fine-tuning on the full final pass
python ../main.py \
--batch_size=$SIZE_OF_BATCH \
--batch_size_val=$SIZE_OF_BATCH \
--checkpoint="$CHECKPOINT" \
--lr_scheduler=MultiStepLR \
--lr_scheduler_gamma=0.5 \
--lr_scheduler_milestones="[620, 697, 774, 812, 822, 831, 870, 908, 947, 985]" \
--model=$MODEL \
--num_workers=4 \
--optimizer=Adam \
--optimizer_lr=1e-05 \
--optimizer_weight_decay=4e-4 \
--save="$SAVE_PATH" \
--start_epoch=448 \
--total_epochs=591 \
--training_augmentation=RandomAffineFlowOccSintel \
--training_augmentation_crop="[384,768]" \
--training_dataset=SintelTrainingFinalFull \
--training_dataset_photometric_augmentations=True \
--training_dataset_root="$SINTEL_HOME" \
--training_key=total_loss \
--training_loss=$EVAL_LOSS \
--validation_dataset=SintelTrainingFinalValid \
--validation_dataset_photometric_augmentations=False \
--validation_dataset_root="$SINTEL_HOME" \
--validation_key=epe \
--validation_loss=$EVAL_LOSS
================================================
FILE: scripts/IRR-PWC_things3d.sh
================================================
#!/bin/bash
# Fine-tune IRR_PWC on FlyingThings3D (clean) starting from the
# FlyingChairsOcc checkpoint; validate on the full Sintel clean training set.
# NOTE(review): replace both "(YOUR PATH)" placeholders before running.
# experiments and datasets meta
EXPERIMENTS_HOME="experiments"
# datasets
FLYINGTHINGS_HOME=(YOUR PATH)/things3d/FlyingThings3D_subset/
SINTEL_HOME=(YOUR PATH)/MPI-Sintel-complete/
# model and checkpoint
MODEL=IRR_PWC
EVAL_LOSS=MultiScaleEPE_PWC_Bi_Occ_upsample
CHECKPOINT="saved_check_point/IRR-PWC_flyingchairsOcc/checkpoint_latest.ckpt"
SIZE_OF_BATCH=4
# save path
TIME=$(date +"%Y%m%d-%H%M%S")
SAVE_PATH="$EXPERIMENTS_HOME/$MODEL-$TIME"
# training configuration
# Epochs 109-159 with step-halving LR schedule.
python ../main.py \
--batch_size=$SIZE_OF_BATCH \
--batch_size_val=$SIZE_OF_BATCH \
--checkpoint=$CHECKPOINT \
--lr_scheduler=MultiStepLR \
--lr_scheduler_gamma=0.5 \
--lr_scheduler_milestones="[128, 139, 149]" \
--model=$MODEL \
--num_workers=4 \
--optimizer=Adam \
--optimizer_lr=1e-5 \
--optimizer_weight_decay=4e-4 \
--save=$SAVE_PATH \
--start_epoch=109 \
--total_epochs=159 \
--training_augmentation=RandomAffineFlowOcc \
--training_augmentation_crop="[384,768]" \
--training_dataset=FlyingThings3dCleanTrain \
--training_dataset_photometric_augmentations=True \
--training_dataset_root=$FLYINGTHINGS_HOME \
--training_key=total_loss \
--training_loss=$EVAL_LOSS \
--validation_dataset=SintelTrainingCleanFull \
--validation_dataset_photometric_augmentations=False \
--validation_dataset_root=$SINTEL_HOME \
--validation_key=epe \
--validation_loss=$EVAL_LOSS
================================================
FILE: scripts/flownet1s.sh
================================================
#!/bin/bash
# Train FlowNet1S from scratch on FlyingChairsOcc; validate on the
# FlyingChairsOcc validation split.
# NOTE(review): replace "(YOUR PATH)" with the dataset location first.
# experiments and datasets meta
EXPERIMENTS_HOME="experiments"
# datasets
FLYINGCHAIRS_OCC_HOME=(YOUR PATH)/flow_occ_v5/data
# model and checkpoint
MODEL=FlowNet1S
EVAL_LOSS=MultiScaleEPE_FlowNet
CHECKPOINT=None
SIZE_OF_BATCH=8
# save path
TIME=$(date +"%Y%m%d-%H%M%S")
SAVE_PATH="$EXPERIMENTS_HOME/$MODEL-$TIME"
# training configuration
python ../main.py \
--batch_size=$SIZE_OF_BATCH \
--batch_size_val=$SIZE_OF_BATCH \
--checkpoint=$CHECKPOINT \
--lr_scheduler=MultiStepLR \
--lr_scheduler_gamma=0.5 \
--lr_scheduler_milestones="[108, 144, 180]" \
--model=$MODEL \
--num_workers=4 \
--num_iters=1 \
--optimizer=Adam \
--optimizer_lr=1e-4 \
--optimizer_weight_decay=4e-4 \
--save=$SAVE_PATH \
--total_epochs=216 \
--training_augmentation=RandomAffineFlowOcc \
--training_dataset=FlyingChairsOccTrain \
--training_dataset_photometric_augmentations=True \
--training_dataset_root=$FLYINGCHAIRS_OCC_HOME \
--training_key=total_loss \
--training_loss=$EVAL_LOSS \
--validation_dataset=FlyingChairsOccValid \
--validation_dataset_photometric_augmentations=False \
--validation_dataset_root=$FLYINGCHAIRS_OCC_HOME \
--validation_key=epe \
--validation_loss=$EVAL_LOSS
================================================
FILE: scripts/flownet1s_irr1.sh
================================================
#!/bin/bash
# Train FlowNet1S_irr (1 iterative refinement step, --num_iters=1) from
# scratch on FlyingChairsOcc.
# NOTE(review): replace "(YOUR PATH)" with the dataset location first.
# experiments and datasets meta
EXPERIMENTS_HOME="experiments"
# datasets
FLYINGCHAIRS_OCC_HOME=(YOUR PATH)/flow_occ_v5/data
# model and checkpoint
MODEL=FlowNet1S_irr
EVAL_LOSS=MultiScaleEPE_FlowNet_IRR
CHECKPOINT=None
SIZE_OF_BATCH=8
# save path
TIME=$(date +"%Y%m%d-%H%M%S")
SAVE_PATH="$EXPERIMENTS_HOME/$MODEL-$TIME"
# training configuration
python ../main.py \
--batch_size=$SIZE_OF_BATCH \
--batch_size_val=$SIZE_OF_BATCH \
--checkpoint=$CHECKPOINT \
--lr_scheduler=MultiStepLR \
--lr_scheduler_gamma=0.5 \
--lr_scheduler_milestones="[108, 144, 180]" \
--model=$MODEL \
--num_workers=4 \
--num_iters=1 \
--optimizer=Adam \
--optimizer_lr=1e-4 \
--optimizer_weight_decay=4e-4 \
--save=$SAVE_PATH \
--total_epochs=216 \
--training_augmentation=RandomAffineFlowOcc \
--training_dataset=FlyingChairsOccTrain \
--training_dataset_photometric_augmentations=True \
--training_dataset_root=$FLYINGCHAIRS_OCC_HOME \
--training_key=total_loss \
--training_loss=$EVAL_LOSS \
--validation_dataset=FlyingChairsOccValid \
--validation_dataset_photometric_augmentations=False \
--validation_dataset_root=$FLYINGCHAIRS_OCC_HOME \
--validation_key=epe \
--validation_loss=$EVAL_LOSS
================================================
FILE: scripts/flownet1s_irr2.sh
================================================
#!/bin/bash
# Train FlowNet1S_irr with 2 iterative refinement steps (--num_iters=2);
# batch size and epoch counts are halved relative to flownet1s_irr1.sh.
# NOTE(review): replace "(YOUR PATH)" with the dataset location first.
# experiments and datasets meta
EXPERIMENTS_HOME="experiments"
# datasets
FLYINGCHAIRS_OCC_HOME=(YOUR PATH)/flow_occ_v5/data
# model and checkpoint
MODEL=FlowNet1S_irr
EVAL_LOSS=MultiScaleEPE_FlowNet_IRR
CHECKPOINT=None
SIZE_OF_BATCH=4
# save path
TIME=$(date +"%Y%m%d-%H%M%S")
SAVE_PATH="$EXPERIMENTS_HOME/$MODEL-$TIME"
# training configuration
python ../main.py \
--batch_size=$SIZE_OF_BATCH \
--batch_size_val=$SIZE_OF_BATCH \
--checkpoint=$CHECKPOINT \
--lr_scheduler=MultiStepLR \
--lr_scheduler_gamma=0.5 \
--lr_scheduler_milestones="[54, 72, 90]" \
--model=$MODEL \
--num_workers=4 \
--num_iters=2 \
--optimizer=Adam \
--optimizer_lr=1e-4 \
--optimizer_weight_decay=4e-4 \
--save=$SAVE_PATH \
--total_epochs=108 \
--training_augmentation=RandomAffineFlowOcc \
--training_dataset=FlyingChairsOccTrain \
--training_dataset_photometric_augmentations=True \
--training_dataset_root=$FLYINGCHAIRS_OCC_HOME \
--training_key=total_loss \
--training_loss=$EVAL_LOSS \
--validation_dataset=FlyingChairsOccValid \
--validation_dataset_photometric_augmentations=False \
--validation_dataset_root=$FLYINGCHAIRS_OCC_HOME \
--validation_key=epe \
--validation_loss=$EVAL_LOSS
================================================
FILE: scripts/pwcnet.sh
================================================
#!/bin/bash
# Train the baseline PWCNet from scratch on FlyingChairsOcc.
# NOTE(review): replace "(YOUR PATH)" with the dataset location first.
# experiments and datasets meta
EXPERIMENTS_HOME="experiments"
# datasets
FLYINGCHAIRS_OCC_HOME=(YOUR PATH)/flow_occ_v5/data
# model and checkpoint
MODEL=PWCNet
EVAL_LOSS=MultiScaleEPE_PWC
CHECKPOINT=None
SIZE_OF_BATCH=8
# save path
TIME=$(date +"%Y%m%d-%H%M%S")
SAVE_PATH="$EXPERIMENTS_HOME/$MODEL-$TIME"
# training configuration
python ../main.py \
--batch_size=$SIZE_OF_BATCH \
--batch_size_val=$SIZE_OF_BATCH \
--checkpoint=$CHECKPOINT \
--lr_scheduler=MultiStepLR \
--lr_scheduler_gamma=0.5 \
--lr_scheduler_milestones="[108, 144, 180]" \
--model=$MODEL \
--num_workers=4 \
--optimizer=Adam \
--optimizer_lr=1e-4 \
--optimizer_weight_decay=4e-4 \
--save=$SAVE_PATH \
--total_epochs=216 \
--training_augmentation=RandomAffineFlowOcc \
--training_dataset=FlyingChairsOccTrain \
--training_dataset_photometric_augmentations=True \
--training_dataset_root=$FLYINGCHAIRS_OCC_HOME \
--training_key=total_loss \
--training_loss=$EVAL_LOSS \
--validation_dataset=FlyingChairsOccValid \
--validation_dataset_photometric_augmentations=False \
--validation_dataset_root=$FLYINGCHAIRS_OCC_HOME \
--validation_key=epe \
--validation_loss=$EVAL_LOSS
================================================
FILE: scripts/pwcnet_irr.sh
================================================
#!/bin/bash
# Train PWCNet_irr (weight-sharing iterative variant) from scratch on
# FlyingChairsOcc; same schedule as the baseline pwcnet.sh.
# NOTE(review): replace "(YOUR PATH)" with the dataset location first.
# experiments and datasets meta
EXPERIMENTS_HOME="experiments"
# datasets
FLYINGCHAIRS_OCC_HOME=(YOUR PATH)/flow_occ_v5/data
# model and checkpoint
MODEL=PWCNet_irr
EVAL_LOSS=MultiScaleEPE_PWC
CHECKPOINT=None
SIZE_OF_BATCH=8
# save path
TIME=$(date +"%Y%m%d-%H%M%S")
SAVE_PATH="$EXPERIMENTS_HOME/$MODEL-$TIME"
# training configuration
python ../main.py \
--batch_size=$SIZE_OF_BATCH \
--batch_size_val=$SIZE_OF_BATCH \
--checkpoint=$CHECKPOINT \
--lr_scheduler=MultiStepLR \
--lr_scheduler_gamma=0.5 \
--lr_scheduler_milestones="[108, 144, 180]" \
--model=$MODEL \
--num_workers=4 \
--optimizer=Adam \
--optimizer_lr=1e-4 \
--optimizer_weight_decay=4e-4 \
--save=$SAVE_PATH \
--total_epochs=216 \
--training_augmentation=RandomAffineFlowOcc \
--training_dataset=FlyingChairsOccTrain \
--training_dataset_photometric_augmentations=True \
--training_dataset_root=$FLYINGCHAIRS_OCC_HOME \
--training_key=total_loss \
--training_loss=$EVAL_LOSS \
--validation_dataset=FlyingChairsOccValid \
--validation_dataset_photometric_augmentations=False \
--validation_dataset_root=$FLYINGCHAIRS_OCC_HOME \
--validation_key=epe \
--validation_loss=$EVAL_LOSS
================================================
FILE: scripts/validation/IRR-FlowNet_flyingChairs.sh
================================================
#!/bin/bash
# Evaluate the FlyingChairs-trained IRR_FlowNet checkpoint on the full
# Sintel clean training set (--evaluation=True: validation only).
# NOTE(review): replace "(YOUR PATH)" with the real dataset root first.
# experiments and datasets meta
EXPERIMENTS_HOME="saved_check_point/flownet"
# datasets
SINTEL_HOME=(YOUR PATH)/MPI-Sintel-complete/
# model and checkpoint
MODEL=IRR_FlowNet
CHECKPOINT="$EXPERIMENTS_HOME/IRR-FlowNet_flyingChairs/checkpoint_best.ckpt"
EVAL_LOSS=MultiScaleEPE_FlowNet_IRR_Bi_Occ_upsample
SIZE_OF_BATCH=4
# validate clean configuration
SAVE_PATH="$EXPERIMENTS_HOME/eval_temp/$MODEL"
python ../../main.py \
--batch_size=$SIZE_OF_BATCH \
--batch_size_val=$SIZE_OF_BATCH \
--checkpoint=$CHECKPOINT \
--evaluation=True \
--model=$MODEL \
--num_workers=4 \
--num_iters=2 \
--save=$SAVE_PATH \
--validation_dataset=SintelTrainingCleanFull \
--validation_dataset_photometric_augmentations=False \
--validation_dataset_root=$SINTEL_HOME \
--validation_key=epe \
--validation_loss=$EVAL_LOSS
================================================
FILE: scripts/validation/IRR-PWC_flyingChairs.sh
================================================
#!/bin/bash
# Evaluate the FlyingChairsOcc-trained IRR_PWC checkpoint on the full
# Sintel clean training set (--evaluation=True: validation only).
# NOTE(review): replace "(YOUR PATH)" with the real dataset root first.
# experiments and datasets meta
EXPERIMENTS_HOME="saved_check_point/pwcnet"
# datasets
SINTEL_HOME=(YOUR PATH)/MPI-Sintel-complete/
# model and checkpoint
MODEL=IRR_PWC
CHECKPOINT="$EXPERIMENTS_HOME/IRR-PWC_flyingchairsOcc/checkpoint_best.ckpt"
EVAL_LOSS=MultiScaleEPE_PWC_Bi_Occ_upsample
SIZE_OF_BATCH=4
# validate clean configuration
SAVE_PATH="$EXPERIMENTS_HOME/eval_temp/$MODEL"
python ../../main.py \
--batch_size=$SIZE_OF_BATCH \
--batch_size_val=$SIZE_OF_BATCH \
--checkpoint=$CHECKPOINT \
--evaluation=True \
--model=$MODEL \
--num_workers=4 \
--save=$SAVE_PATH \
--validation_dataset=SintelTrainingCleanFull \
--validation_dataset_photometric_augmentations=False \
--validation_dataset_root=$SINTEL_HOME \
--validation_key=epe \
--validation_loss=$EVAL_LOSS
================================================
FILE: scripts/validation/IRR-PWC_kitti.sh
================================================
#!/bin/bash
# Evaluate the KITTI-trained IRR_PWC checkpoint on KittiCombVal
# (batch size 1, --evaluation=True: validation only).
# NOTE(review): replace "(YOUR PATH)" with the real dataset root first.
# experiments and datasets meta
EXPERIMENTS_HOME="saved_check_point/pwcnet"
# datasets
KITTI_HOME=(YOUR PATH)/KITTI_flow/
# model and checkpoint
MODEL=IRR_PWC
CHECKPOINT="$EXPERIMENTS_HOME/IRR-PWC_kitti/checkpoint_latest.ckpt"
EVAL_LOSS=MultiScaleEPE_PWC_Bi_Occ_upsample_KITTI
SIZE_OF_BATCH=1
# validate clean configuration
SAVE_PATH="$EXPERIMENTS_HOME/eval_temp/$MODEL"
python ../../main.py \
--batch_size=$SIZE_OF_BATCH \
--batch_size_val=$SIZE_OF_BATCH \
--checkpoint=$CHECKPOINT \
--evaluation=True \
--model=$MODEL \
--num_workers=4 \
--save=$SAVE_PATH \
--validation_dataset=KittiCombVal \
--validation_dataset_photometric_augmentations=False \
--validation_dataset_root=$KITTI_HOME \
--validation_key=epe \
--validation_loss=$EVAL_LOSS
================================================
FILE: scripts/validation/IRR-PWC_sintel.sh
================================================
#!/bin/bash
# Evaluate the Sintel-trained IRR_PWC checkpoint on the Sintel final-pass
# validation split (--evaluation=True: validation only).
# NOTE(review): replace "(YOUR PATH)" with the real dataset root first.
# experiments and datasets meta
EXPERIMENTS_HOME="saved_check_point/pwcnet"
# datasets
SINTEL_HOME=(YOUR PATH)/MPI-Sintel-complete/
# model and checkpoint
MODEL=IRR_PWC
CHECKPOINT="$EXPERIMENTS_HOME/IRR-PWC_sintel/checkpoint_latest.ckpt"
EVAL_LOSS=MultiScaleEPE_PWC_Bi_Occ_upsample_Sintel
SIZE_OF_BATCH=4
# validate clean configuration
SAVE_PATH="$EXPERIMENTS_HOME/eval_temp/$MODEL"
python ../../main.py \
--batch_size=$SIZE_OF_BATCH \
--batch_size_val=$SIZE_OF_BATCH \
--checkpoint=$CHECKPOINT \
--evaluation=True \
--model=$MODEL \
--num_workers=4 \
--save=$SAVE_PATH \
--validation_dataset=SintelTrainingFinalValid \
--validation_dataset_photometric_augmentations=False \
--validation_dataset_root=$SINTEL_HOME \
--validation_key=epe \
--validation_loss=$EVAL_LOSS
================================================
FILE: scripts/validation/IRR-PWC_things3d.sh
================================================
#!/bin/bash
# Evaluate the FlyingThings3D-trained IRR_PWC checkpoint on the full
# Sintel clean training set (--evaluation=True: validation only).
# NOTE(review): replace "(YOUR PATH)" with the real dataset root first.
# experiments and datasets meta
EXPERIMENTS_HOME="saved_check_point/pwcnet"
# datasets
SINTEL_HOME=(YOUR PATH)/MPI-Sintel-complete/
# model and checkpoint
MODEL=IRR_PWC
CHECKPOINT="$EXPERIMENTS_HOME/IRR-PWC_things3d/checkpoint_latest.ckpt"
EVAL_LOSS=MultiScaleEPE_PWC_Bi_Occ_upsample
SIZE_OF_BATCH=4
# validate clean configuration
SAVE_PATH="$EXPERIMENTS_HOME/eval_temp/$MODEL"
python ../../main.py \
--batch_size=$SIZE_OF_BATCH \
--batch_size_val=$SIZE_OF_BATCH \
--checkpoint=$CHECKPOINT \
--evaluation=True \
--model=$MODEL \
--num_workers=4 \
--save=$SAVE_PATH \
--validation_dataset=SintelTrainingCleanFull \
--validation_dataset_photometric_augmentations=False \
--validation_dataset_root=$SINTEL_HOME \
--validation_key=epe \
--validation_loss=$EVAL_LOSS
================================================
FILE: scripts/validation/flownet1s.sh
================================================
#!/bin/bash
# Evaluate the FlowNet1S checkpoint on the full Sintel clean training set
# (--evaluation=True: validation only, --num_iters=1).
# NOTE(review): replace "(YOUR PATH)" with the real dataset root first.
# experiments and datasets meta
EXPERIMENTS_HOME="saved_check_point/flownet"
# datasets
SINTEL_HOME=(YOUR PATH)/MPI-Sintel-complete/
# model and checkpoint
MODEL=FlowNet1S
CHECKPOINT="$EXPERIMENTS_HOME/FlowNet1S/checkpoint_best.ckpt"
EVAL_LOSS=MultiScaleEPE_FlowNet
SIZE_OF_BATCH=4
# validate clean configuration
SAVE_PATH="$EXPERIMENTS_HOME/eval_temp/$MODEL"
python ../../main.py \
--batch_size=$SIZE_OF_BATCH \
--batch_size_val=$SIZE_OF_BATCH \
--checkpoint=$CHECKPOINT \
--evaluation=True \
--model=$MODEL \
--num_workers=4 \
--num_iters=1 \
--save=$SAVE_PATH \
--validation_dataset=SintelTrainingCleanFull \
--validation_dataset_photometric_augmentations=False \
--validation_dataset_root=$SINTEL_HOME \
--validation_key=epe \
--validation_loss=$EVAL_LOSS
================================================
FILE: scripts/validation/flownet1s_irr1.sh
================================================
#!/bin/bash
# Evaluate the 1-iteration FlowNet1S_irr checkpoint on the full Sintel
# clean training set (--evaluation=True, --num_iters=1).
# NOTE(review): replace "(YOUR PATH)" with the real dataset root first.
# experiments and datasets meta
EXPERIMENTS_HOME="saved_check_point/flownet"
# datasets
SINTEL_HOME=(YOUR PATH)/MPI-Sintel-complete/
# model and checkpoint
MODEL=FlowNet1S_irr
CHECKPOINT="$EXPERIMENTS_HOME/FlowNet1S-irr1/checkpoint_best.ckpt"
EVAL_LOSS=MultiScaleEPE_FlowNet_IRR
SIZE_OF_BATCH=4
# validate clean configuration
SAVE_PATH="$EXPERIMENTS_HOME/eval_temp/$MODEL"
python ../../main.py \
--batch_size=$SIZE_OF_BATCH \
--batch_size_val=$SIZE_OF_BATCH \
--checkpoint=$CHECKPOINT \
--evaluation=True \
--model=$MODEL \
--num_workers=4 \
--num_iters=1 \
--save=$SAVE_PATH \
--validation_dataset=SintelTrainingCleanFull \
--validation_dataset_photometric_augmentations=False \
--validation_dataset_root=$SINTEL_HOME \
--validation_key=epe \
--validation_loss=$EVAL_LOSS
================================================
FILE: scripts/validation/flownet1s_irr2.sh
================================================
#!/bin/bash
# Evaluate the 2-iteration FlowNet1S_irr checkpoint on the full Sintel
# clean training set (--evaluation=True, --num_iters=2).
# NOTE(review): replace "(YOUR PATH)" with the real dataset root first.
# experiments and datasets meta
EXPERIMENTS_HOME="saved_check_point/flownet"
# datasets
SINTEL_HOME=(YOUR PATH)/MPI-Sintel-complete/
# model and checkpoint
MODEL=FlowNet1S_irr
CHECKPOINT="$EXPERIMENTS_HOME/FlowNet1S-irr2/checkpoint_best.ckpt"
EVAL_LOSS=MultiScaleEPE_FlowNet_IRR
SIZE_OF_BATCH=4
# validate clean configuration
SAVE_PATH="$EXPERIMENTS_HOME/eval_temp/$MODEL"
python ../../main.py \
--batch_size=$SIZE_OF_BATCH \
--batch_size_val=$SIZE_OF_BATCH \
--checkpoint=$CHECKPOINT \
--evaluation=True \
--model=$MODEL \
--num_workers=4 \
--num_iters=2 \
--save=$SAVE_PATH \
--validation_dataset=SintelTrainingCleanFull \
--validation_dataset_photometric_augmentations=False \
--validation_dataset_root=$SINTEL_HOME \
--validation_key=epe \
--validation_loss=$EVAL_LOSS
================================================
FILE: scripts/validation/pwcnet.sh
================================================
#!/bin/bash
# Evaluate the baseline PWCNet checkpoint on the full Sintel clean
# training set (--evaluation=True, batch size 1).
# NOTE(review): replace "(YOUR PATH)" with the real dataset root first.
# experiments and datasets meta
EXPERIMENTS_HOME="saved_check_point/pwcnet"
# datasets
SINTEL_HOME=(YOUR PATH)/MPI-Sintel-complete/
# model and checkpoint
MODEL=PWCNet
CHECKPOINT="$EXPERIMENTS_HOME/PWCNet/checkpoint_best.ckpt"
EVAL_LOSS=MultiScaleEPE_PWC
SIZE_OF_BATCH=1
# validate clean configuration
SAVE_PATH="$EXPERIMENTS_HOME/eval_temp/$MODEL"
python ../../main.py \
--batch_size=$SIZE_OF_BATCH \
--batch_size_val=$SIZE_OF_BATCH \
--checkpoint=$CHECKPOINT \
--evaluation=True \
--model=$MODEL \
--num_workers=4 \
--save=$SAVE_PATH \
--validation_dataset=SintelTrainingCleanFull \
--validation_dataset_photometric_augmentations=False \
--validation_dataset_root=$SINTEL_HOME \
--validation_key=epe \
--validation_loss=$EVAL_LOSS
================================================
FILE: scripts/validation/pwcnet_irr.sh
================================================
#!/bin/bash
# Evaluate the PWCNet-irr checkpoint on the full Sintel clean training set
# (--evaluation=True: validation only).
# NOTE(review): replace "(YOUR PATH)" with the real dataset root first.
# experiments and datasets meta
EXPERIMENTS_HOME="saved_check_point/pwcnet"
# datasets
SINTEL_HOME=(YOUR PATH)/MPI-Sintel-complete/
# model and checkpoint
MODEL=PWCNet_irr
CHECKPOINT="$EXPERIMENTS_HOME/PWCNet-irr/checkpoint_best.ckpt"
EVAL_LOSS=MultiScaleEPE_PWC
SIZE_OF_BATCH=4
# validate clean configuration
SAVE_PATH="$EXPERIMENTS_HOME/eval_temp/$MODEL"
python ../../main.py \
--batch_size=$SIZE_OF_BATCH \
--batch_size_val=$SIZE_OF_BATCH \
--checkpoint=$CHECKPOINT \
--evaluation=True \
--model=$MODEL \
--num_workers=4 \
--save=$SAVE_PATH \
--validation_dataset=SintelTrainingCleanFull \
--validation_dataset_photometric_augmentations=False \
--validation_dataset_root=$SINTEL_HOME \
--validation_key=epe \
--validation_loss=$EVAL_LOSS
================================================
FILE: tools.py
================================================
## Portions of Code from, copyright 2018 Jochen Gast
from __future__ import absolute_import, division, print_function
import os
import socket
import re
from pytz import timezone
from datetime import datetime
import fnmatch
import itertools
import argparse
import sys
import six
import unicodedata
import json
import inspect
import tqdm
import logging
import torch
import ast
import numpy as np
def x2module(module_or_data_parallel):
    """Unwrap a DataParallel container and return the underlying module.

    Anything that is not a DataParallel instance is returned unchanged.
    """
    if not isinstance(module_or_data_parallel, torch.nn.DataParallel):
        return module_or_data_parallel
    return module_or_data_parallel.module
# ----------------------------------------------------------------------------------------
# Comprehensively adds a new logging level to the `logging` module and the
# currently configured logging class.
# e.g. addLoggingLevel('TRACE', logging.DEBUG - 5)
# ----------------------------------------------------------------------------------------
def addLoggingLevel(level_name, level_num, method_name=None):
    """Register `level_name` at `level_num` plus a convenience log method.

    After the call, `logging.<LEVEL_NAME>` holds the numeric level,
    `logging.<method_name>(...)` logs at that level on the root logger, and
    every logger instance gains a `.<method_name>(...)` method.
    Raises AttributeError if any of the names already exist.
    """
    method_name = method_name or level_name.lower()

    # Refuse to clobber anything already present on the logging module or
    # the active logger class.
    if hasattr(logging, level_name):
        raise AttributeError('{} already defined in logging module'.format(level_name))
    if hasattr(logging, method_name):
        raise AttributeError('{} already defined in logging module'.format(method_name))
    if hasattr(logging.getLoggerClass(), method_name):
        raise AttributeError('{} already defined in logger class'.format(method_name))

    # Per-logger method: respects the logger's effective level.
    def log_for_level(self, message, *args, **kwargs):
        if self.isEnabledFor(level_num):
            self._log(level_num, message, args, **kwargs)

    # Module-level convenience: logs via the root logger.
    def log_to_root(message, *args, **kwargs):
        logging.log(level_num, message, *args, **kwargs)

    logging.addLevelName(level_num, level_name)
    setattr(logging, level_name, level_num)
    setattr(logging.getLoggerClass(), method_name, log_for_level)
    setattr(logging, method_name, log_to_root)
# -------------------------------------------------------------------------------------------------
# Looks for sub arguments in the argument structure.
# Retrieve sub arguments for modules such as optimizer_*
# -------------------------------------------------------------------------------------------------
def kwargs_from_args(args, name, exclude=[]):
    """Collect `<name>_*` attributes of an argparse namespace as a kwargs dict.

    E.g. name="optimizer" maps optimizer_lr -> {"lr": ...}. The implicit
    "<name>_class" entry is always excluded.

    Fixes over the original:
    - `exclude` is no longer mutated in place (the original `exclude +=`
      grew the shared default list and the caller's list).
    - keys must start with "<name>_" (the original substring test
      `name in key` produced garbage slices for keys such as
      "my_optimizer_foo").
    """
    if isinstance(exclude, str):
        exclude = [exclude]
    # Copy before extending so neither the default nor the caller's list
    # is modified.
    exclude = list(exclude) + ["class"]
    prefix = name + "_"
    args_dict = vars(args)
    return {
        key[len(prefix):]: value for key, value in args_dict.items()
        if key.startswith(prefix) and all(key != prefix + x for x in exclude)
    }
# -------------------------------------------------------------------------------------------------
# Create class instance from kwargs dictionary.
# Filters out keys that not in the constructor
# -------------------------------------------------------------------------------------------------
def instance_from_kwargs(class_constructor, kwargs):
    """Instantiate `class_constructor` with only the kwargs its __init__ accepts.

    Unknown keys in `kwargs` are silently dropped.
    """
    # inspect.getargspec was removed in Python 3.11; getfullargspec is the
    # drop-in replacement and additionally exposes keyword-only parameters.
    argspec = inspect.getfullargspec(class_constructor.__init__)
    accepted = set(argspec.args) | set(argspec.kwonlyargs)
    filtered_args = {k: v for k, v in kwargs.items() if k in accepted}
    return class_constructor(**filtered_args)
def module_classes_to_dict(module, include_classes="*", exclude_classes=()):
    """Map class names defined/exported by `module` to the class objects.

    `include_classes` and `exclude_classes` are glob patterns (string or
    list of strings) applied to the class names.
    """
    # Accept bare strings as single-pattern lists.
    if isinstance(include_classes, str):
        include_classes = [include_classes]
    if isinstance(exclude_classes, str):
        exclude_classes = [exclude_classes]

    # Every attribute of the module that is a class.
    class_dict = {
        attr_name: getattr(module, attr_name)
        for attr_name in dir(module)
        if inspect.isclass(getattr(module, attr_name))
    }

    # Glob-filter the class names, then keep only the surviving entries.
    kept_names = filter_list_of_strings(
        class_dict.keys(), include=include_classes, exclude=exclude_classes)
    return {attr_name: cls for attr_name, cls in class_dict.items()
            if attr_name in kept_names}
def ensure_dir(file_path):
    """Create the parent directory of `file_path` if it does not exist yet.

    Fixes over the original:
    - a bare filename (dirname == "") no longer crashes in makedirs;
    - `exist_ok=True` removes the exists/create race when several
      processes prepare the same directory.
    """
    directory = os.path.dirname(file_path)
    if directory:
        os.makedirs(directory, exist_ok=True)
def search_and_replace(string, regex, replace):
    """Repeatedly replace text matched by `regex` with `replace`.

    Each round substitutes every occurrence of the matched literal text,
    and the search restarts until the pattern no longer matches (so
    replacements that re-expose the pattern are handled too).
    """
    match = re.search(regex, string)
    while match is not None:
        string = string.replace(match.group(0), replace)
        match = re.search(regex, string)
    return string
def hostname():
    """Return the machine's host name with any domain suffix stripped.

    A leading dot (index 0) is deliberately not treated as a cut point,
    matching the original `find(...) > 0` test.
    """
    full_name = socket.gethostname()
    dot_index = full_name.find('.')
    return full_name[:dot_index] if dot_index > 0 else full_name
def get_filenames(directory, match='*.*', not_match=()):
    """Recursively collect file paths under `directory`.

    A file is kept when its basename matches any glob in `match` and none
    in `not_match`. Both arguments accept a single pattern string or a
    sequence of patterns. Order within a directory is unspecified (set
    difference), as in the original.
    """
    match = [match] if isinstance(match, str) else match
    not_match = [not_match] if isinstance(not_match, str) else not_match

    collected = []
    for dirpath, _, filenames in os.walk(directory):
        wanted = set(itertools.chain.from_iterable(
            fnmatch.filter(filenames, pattern) for pattern in match))
        unwanted = set(itertools.chain.from_iterable(
            fnmatch.filter(filenames, pattern) for pattern in not_match))
        collected.extend(os.path.join(dirpath, filename)
                         for filename in wanted - unwanted)
    return collected
def str2bool(v):
    """Parse common textual spellings of a boolean (argparse type helper)."""
    lowered = v.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
def str2str_or_none(v):
    """Map the literal string "none" (any case) to None; pass others through."""
    return None if v.lower() == "none" else v
def str2dict(v):
    """Safely evaluate a Python-literal string (e.g. "{'a': 1}")."""
    parsed = ast.literal_eval(v)
    return parsed
def str2intlist(v):
    """Parse a bracketed, comma-separated string ("[1, 2, 3]") into ints."""
    inner = v.strip()[1:-1]
    return [int(part) for part in inner.split(',')]
def str2list(v):
    """Parse a bracketed, comma-separated string into stripped substrings."""
    inner = v.strip()[1:-1]
    return [part.strip() for part in inner.split(',')]
def read_json(filename):
    """Read a JSON file, ASCII-encoding every string key and value.

    Fixes over the original:
    - dropped the `encoding` kwarg of json.loads, which was removed in
      Python 3.9 and made this function raise TypeError on modern Python;
    - replaced `six.iteritems` / `six.string_types` with the exactly
      equivalent Python 3 builtins;
    - the bare `except:` is narrowed to ValueError (JSONDecodeError's base).

    NOTE(review): the NFKD-normalize + ascii-encode step turns str into
    bytes on Python 3 (keys come back as b"..."). That Python-2-era
    behavior is preserved here because callers may depend on it -- confirm
    before simplifying.
    """
    def _convert_from_unicode(data):
        # Recursively normalize/encode string keys and values; nested dicts
        # are converted in place of their originals.
        new_data = dict()
        for name, value in data.items():
            if isinstance(name, str):
                name = unicodedata.normalize('NFKD', name).encode(
                    'ascii', 'ignore')
            if isinstance(value, str):
                value = unicodedata.normalize('NFKD', value).encode(
                    'ascii', 'ignore')
            if isinstance(value, dict):
                value = _convert_from_unicode(value)
            new_data[name] = value
        return new_data

    with open(filename, "r") as f:
        content = f.read()
    try:
        output_dict = json.loads(content)
    except ValueError:
        raise ValueError('Could not read %s. %s' % (filename, sys.exc_info()[1]))
    return _convert_from_unicode(output_dict)
def write_json(data_dict, filename):
    """Serialize `data_dict` as JSON into `filename`."""
    with open(filename, "w") as out_file:
        json.dump(data_dict, out_file)
def datestr():
    """Return the current US/Pacific time formatted as 'YYYYMMDD_HHMM'.

    Uses the pytz `timezone` imported at module level.
    """
    pacific_now = datetime.now(timezone('US/Pacific'))
    return '{}{:02}{:02}_{:02}{:02}'.format(
        pacific_now.year, pacific_now.month, pacific_now.day,
        pacific_now.hour, pacific_now.minute)
def filter_list_of_strings(lst, include="*", exclude=()):
    """Glob-filter `lst`: names matching any include pattern minus any
    exclude match. Result order is unspecified (set difference), as in the
    original.
    """
    included = set(itertools.chain.from_iterable(
        fnmatch.filter(lst, pattern) for pattern in include))
    excluded = set(itertools.chain.from_iterable(
        fnmatch.filter(lst, pattern) for pattern in exclude))
    return list(included - excluded)
# ----------------------------------------------------------------------------
# Writes all pairs to a filename for book keeping
# Either .txt or .json
# ----------------------------------------------------------------------------
def write_dictionary_to_file(arguments_dict, filename):
    """Write (key, value) pairs to `filename` as .json or plain 'key: value' text.

    `arguments_dict` may be a dict or any iterable of (key, value) pairs.
    (Despite its name, the original only worked for pair iterables --
    iterating a real dict yielded bare keys and failed to unpack -- and
    broke on generators because of the len() call. Both now work.)
    """
    # Normalize to a concrete list of pairs so len() below is safe.
    if isinstance(arguments_dict, dict):
        items = list(arguments_dict.items())
    else:
        items = list(arguments_dict)

    # ensure dir ("" dirname means cwd: nothing to create)
    d = os.path.dirname(filename)
    if d and not os.path.exists(d):
        os.makedirs(d)
    # check for json extension
    ext = os.path.splitext(filename)[1]
    if ext == ".json":
        def replace_quotes(x):
            # str() of lists/tuples uses single quotes; JSON needs double.
            return x.replace("\'", "\"")
        with open(filename, 'w') as file:
            file.write("{\n")
            for i, (key, value) in enumerate(items):
                if isinstance(value, tuple):
                    value = list(value)
                if value is None:
                    file.write(" \"%s\": null" % key)
                elif isinstance(value, str):
                    file.write(" \"%s\": \"%s\"" % (key, replace_quotes(str(value))))
                elif isinstance(value, bool):
                    # JSON booleans are lowercase.
                    file.write(" \"%s\": %s" % (key, str(value).lower()))
                else:
                    file.write(" \"%s\": %s" % (key, replace_quotes(str(value))))
                file.write(',\n' if i < len(items) - 1 else '\n')
            file.write("}\n")
    else:
        with open(filename, 'w') as file:
            for key, value in items:
                file.write('%s: %s\n' % (key, value))
class MovingAverage:
    """Running arithmetic mean over added samples."""
    postfix = "avg"  # tag used when rendering this statistic

    def __init__(self):
        self._sum = 0.0
        self._count = 0

    def add_value(self, sigma, addcount=1):
        # Fold in one raw sample (weighted by addcount).
        self._sum = self._sum + sigma
        self._count = self._count + addcount

    def add_average(self, avg, addcount):
        # Fold in a pre-computed average over `addcount` samples.
        self._sum = self._sum + avg * addcount
        self._count = self._count + addcount

    def mean(self):
        # Raises ZeroDivisionError if nothing was added yet.
        return self._sum / self._count
class ExponentialMovingAverage:
    """Exponentially weighted mean; recent samples dominate for larger alpha."""
    postfix = "ema"  # tag used when rendering this statistic

    def __init__(self, alpha=0.7):
        self._weighted_sum = 0.0
        self._weighted_count = 0
        self._alpha = alpha

    def add_value(self, sigma, addcount=1):
        # NOTE: `addcount` is accepted but unused, matching the original
        # signature -- a single sample always contributes weight 1.
        decay = 1.0 - self._alpha
        self._weighted_sum = sigma + decay * self._weighted_sum
        self._weighted_count = 1 + decay * self._weighted_count

    def add_average(self, avg, addcount):
        # A batch average contributes with weight `addcount`.
        decay = 1.0 - self._alpha
        self._weighted_sum = avg * addcount + decay * self._weighted_sum
        self._weighted_count = addcount + decay * self._weighted_count

    def mean(self):
        # Raises ZeroDivisionError if nothing was added yet.
        return self._weighted_sum / self._weighted_count
# -----------------------------------------------------------------
# Subclass tqdm to achieve two things:
# 1) Output the progress bar into the logbook.
# 2) Remove the comma before {postfix} because it's annoying.
# -----------------------------------------------------------------
class TqdmToLogger(tqdm.tqdm):
    # Progress bar that mirrors its rendered text into the logging system.
    # NOTE(review): `logging.logbook` is not a stdlib method -- presumably
    # it is installed elsewhere via addLoggingLevel('LOGBOOK', ...) from
    # this file; confirm where the level is registered.

    def __init__(self, iterable=None, desc=None, total=None, leave=True,
                 file=None, ncols=None, mininterval=0.1,
                 maxinterval=10.0, miniters=None, ascii=None, disable=False,
                 unit='it', unit_scale=False, dynamic_ncols=False,
                 smoothing=0.3, bar_format=None, initial=0, position=None,
                 postfix=None,
                 logging_on_close=True,
                 logging_on_update=False):
        # Every standard tqdm argument is forwarded unchanged; only the two
        # logging_* flags are specific to this subclass.
        super(TqdmToLogger, self).__init__(
            iterable=iterable, desc=desc, total=total, leave=leave,
            file=file, ncols=ncols, mininterval=mininterval,
            maxinterval=maxinterval, miniters=miniters, ascii=ascii, disable=disable,
            unit=unit, unit_scale=unit_scale, dynamic_ncols=dynamic_ncols,
            smoothing=smoothing, bar_format=bar_format, initial=initial, position=position,
            postfix=postfix)
        self._logging_on_close = logging_on_close    # log final state on close()
        self._logging_on_update = logging_on_update  # log on every update()
        self._closed = False                         # guards double logging in close()

    @staticmethod
    def format_meter(n, total, elapsed, ncols=None, prefix='', ascii=False,
                     unit='it', unit_scale=False, rate=None, bar_format=None,
                     postfix=None, unit_divisor=1000):
        # Render with the stock tqdm formatter, then strip the ", " that
        # tqdm inserts before the postfix.
        meter = tqdm.tqdm.format_meter(
            n=n, total=total, elapsed=elapsed, ncols=ncols, prefix=prefix, ascii=ascii,
            unit=unit, unit_scale=unit_scale, rate=rate, bar_format=bar_format,
            postfix=postfix, unit_divisor=unit_divisor)
        # get rid of that stupid comma before the postfix
        if postfix is not None:
            postfix_with_comma = ", %s" % postfix
            meter = meter.replace(postfix_with_comma, postfix)
        return meter

    def update(self, n=1):
        # Optionally mirror the bar's current rendering into the logbook.
        if self._logging_on_update:
            msg = self.__repr__()
            logging.logbook(msg)
        return super(TqdmToLogger, self).update(n=n)

    def close(self):
        # Log the final bar state exactly once, then close normally.
        if self._logging_on_close and not self._closed:
            msg = self.__repr__()
            logging.logbook(msg)
            self._closed = True
        return super(TqdmToLogger, self).close()
def tqdm_with_logging(iterable=None, desc=None, total=None, leave=True,
                      ncols=None, mininterval=0.1,
                      maxinterval=10.0, miniters=None, ascii=None, disable=False,
                      unit="it", unit_scale=False, dynamic_ncols=False,
                      smoothing=0.3, bar_format=None, initial=0, position=None,
                      postfix=None,
                      logging_on_close=True,
                      logging_on_update=False):
    # Convenience factory: builds a TqdmToLogger bar with the same defaults
    # as plain tqdm plus the two logging flags. Pure pass-through wrapper;
    # see TqdmToLogger for the argument semantics.
    return TqdmToLogger(
        iterable=iterable, desc=desc, total=total, leave=leave,
        ncols=ncols, mininterval=mininterval,
        maxinterval=maxinterval, miniters=miniters, ascii=ascii, disable=disable,
        unit=unit, unit_scale=unit_scale, dynamic_ncols=dynamic_ncols,
        smoothing=smoothing, bar_format=bar_format, initial=initial, position=position,
        postfix=postfix,
        logging_on_close=logging_on_close,
        logging_on_update=logging_on_update)
def cd_dotdot(path_or_filename):
    """Absolute path one level above the directory containing the argument."""
    containing_dir = os.path.dirname(path_or_filename)
    return os.path.abspath(os.path.join(containing_dir, ".."))
def cd_dotdotdot(path_or_filename):
    """Absolute path two levels above the directory containing the argument."""
    containing_dir = os.path.dirname(path_or_filename)
    return os.path.abspath(os.path.join(containing_dir, "../.."))
def cd_dotdotdotdot(path_or_filename):
    """Absolute path three levels above the directory containing the argument."""
    containing_dir = os.path.dirname(path_or_filename)
    return os.path.abspath(os.path.join(containing_dir, "../../.."))
def tensor2numpy(tensor):
    """Convert a torch tensor (CHW or NCHW) to a channels-last numpy array.

    Numpy arrays pass through untouched; legacy autograd Variables are
    unwrapped to their data tensor first.
    """
    if isinstance(tensor, np.ndarray):
        return tensor
    if isinstance(tensor, torch.autograd.Variable):
        tensor = tensor.data
    # CHW -> HWC for 3-D tensors, otherwise NCHW -> NHWC.
    axes = [1, 2, 0] if tensor.dim() == 3 else [0, 2, 3, 1]
    return tensor.cpu().numpy().transpose(axes)
================================================
FILE: utils/__init__.py
================================================
================================================
FILE: utils/flow.py
================================================
from __future__ import absolute_import, division, print_function
import numpy as np
import png
import matplotlib.colors as cl
# Magic number written at the head of Middlebury .flo files.
TAG_CHAR = np.array([202021.25], np.float32)
# Flow magnitudes above this threshold are treated as unknown/invalid.
UNKNOWN_FLOW_THRESH = 1e7
def write_flow(filename, uv, v=None):
    """Write optical flow to a Middlebury .flo file.

    Args:
        filename: output file path.
        uv: either an (H, W, 2) array with u and v stacked in the last axis,
            or an (H, W) array holding only the horizontal component when
            `v` is passed separately.
        v: optional (H, W) vertical component; must match `u`'s shape.
    """
    nBands = 2
    if v is None:
        assert (uv.ndim == 3)
        assert (uv.shape[2] == 2)
        u = uv[:, :, 0]
        v = uv[:, :, 1]
    else:
        u = uv
    assert (u.shape == v.shape)
    height, width = u.shape
    # interleave u and v column-wise: u00, v00, u01, v01, ...
    tmp = np.zeros((height, width * nBands))
    tmp[:, np.arange(width) * 2] = u
    tmp[:, np.arange(width) * 2 + 1] = v
    # context manager ensures the file is closed even if a write fails
    # (the original opened/closed the handle manually)
    with open(filename, 'wb') as f:
        # write the header: magic number, then width and height as int32
        # (202021.25 is the .flo magic, same value as module-level TAG_CHAR)
        np.array([202021.25], np.float32).tofile(f)
        np.array(width).astype(np.int32).tofile(f)
        np.array(height).astype(np.int32).tofile(f)
        tmp.astype(np.float32).tofile(f)
def write_flow_png(filename, uv, v=None, mask=None):
    """Write optical flow as a 16-bit 3-channel png (KITTI encoding).

    Args:
        filename: output file path.
        uv: (H, W, 2) flow array, or (H, W) horizontal component when `v`
            is passed separately.
        v: optional (H, W) vertical component.
        mask: optional (H, W) validity mask; defaults to all-valid.
    """
    if v is None:
        assert (uv.ndim == 3)
        assert (uv.shape[2] == 2)
        u = uv[:, :, 0]
        v = uv[:, :, 1]
    else:
        u = uv
    assert (u.shape == v.shape)
    height, width = u.shape
    valid = np.ones([height, width]) if mask is None else mask
    # KITTI convention: value = flow * 64 + 2^15, stored as uint16
    enc_u = np.clip((u * 64 + 2 ** 15), 0.0, 65535.0).astype(np.uint16)
    enc_v = np.clip((v * 64 + 2 ** 15), 0.0, 65535.0).astype(np.uint16)
    packed = np.stack((enc_u, enc_v, valid), axis=-1)
    with open(filename, 'wb') as f:
        writer = png.Writer(width=width, height=height, bitdepth=16)
        # pypng expects one flat row of width*3 values per image row
        writer.write(f, np.reshape(packed, (-1, width * 3)))
def flow_to_png(flow_map, max_value=None):
    """Visualize a (2, H, W) flow map as an RGB image in [0, 1].

    Args:
        flow_map: array of shape (2, H, W), u in channel 0 and v in channel 1.
        max_value: optional normalization constant; defaults to the maximum
            absolute flow value.

    Returns:
        (H, W, 3) float32 array clipped to [0, 1].
    """
    _, h, w = flow_map.shape
    rgb_map = np.ones((h, w, 3)).astype(np.float32)
    if max_value is not None:
        divisor = max_value
    else:
        # bug fix: guard against an all-zero flow map, which previously
        # divided by zero and produced a NaN image
        divisor = np.abs(flow_map).max() or 1.0
    normalized_flow_map = flow_map / divisor
    rgb_map[:, :, 0] += normalized_flow_map[0]
    rgb_map[:, :, 1] -= 0.5 * (normalized_flow_map[0] + normalized_flow_map[1])
    rgb_map[:, :, 2] += normalized_flow_map[1]
    return rgb_map.clip(0, 1)
def compute_color(u, v):
    """
    compute optical flow color map
    :param u: optical flow horizontal map
    :param v: optical flow vertical map
    :return: optical flow in color code
    """
    [h, w] = u.shape
    img = np.zeros([h, w, 3])
    # NOTE(review): zeroing NaNs below writes into the caller's arrays in place.
    nanIdx = np.isnan(u) | np.isnan(v)
    u[nanIdx] = 0
    v[nanIdx] = 0
    colorwheel = make_color_wheel()
    ncols = np.size(colorwheel, 0)
    # polar form of the flow: radius and angle (angle in (-1, 1], units of pi)
    rad = np.sqrt(u ** 2 + v ** 2)
    a = np.arctan2(-v, -u) / np.pi
    # map the angle to a fractional, 1-based index into the color wheel
    fk = (a + 1) / 2 * (ncols - 1) + 1
    k0 = np.floor(fk).astype(int)
    k1 = k0 + 1
    k1[k1 == ncols + 1] = 1  # wrap past the last wheel entry back to the first
    f = fk - k0
    for i in range(0, np.size(colorwheel, 1)):
        tmp = colorwheel[:, i]
        col0 = tmp[k0 - 1] / 255
        col1 = tmp[k1 - 1] / 255
        # linearly blend the two neighboring wheel colors
        col = (1 - f) * col0 + f * col1
        idx = rad <= 1
        # in-range radii: desaturate toward white as the flow gets smaller
        col[idx] = 1 - rad[idx] * (1 - col[idx])
        # out-of-range radii: darken the color
        notidx = np.logical_not(idx)
        col[notidx] *= 0.75
        # NaN pixels are forced to black via the (1 - nanIdx) factor
        img[:, :, i] = np.uint8(np.floor(255 * col * (1 - nanIdx)))
    return img
def make_color_wheel():
    """
    Generate color wheel according Middlebury color code
    :return: Color wheel, shape (55, 3), values in [0, 255]
    """
    # segment lengths between the six hue anchors
    RY, YG, GC, CB, BM, MR = 15, 6, 4, 11, 13, 6
    ncols = RY + YG + GC + CB + BM + MR
    wheel = np.zeros([ncols, 3])
    # each segment holds one channel at 255 while another channel ramps
    # linearly up or down: (length, held channel, ramping channel, ramps_up)
    segments = [
        (RY, 0, 1, True),    # red -> yellow
        (YG, 1, 0, False),   # yellow -> green
        (GC, 1, 2, True),    # green -> cyan
        (CB, 2, 1, False),   # cyan -> blue
        (BM, 2, 0, True),    # blue -> magenta
        (MR, 0, 2, False),   # magenta -> red
    ]
    row = 0
    for length, hold, ramp, ramps_up in segments:
        wheel[row:row + length, hold] = 255
        ramp_vals = np.floor(255 * np.arange(0, length) / length)
        wheel[row:row + length, ramp] = ramp_vals if ramps_up else 255 - ramp_vals
        row += length
    return wheel
def flow_to_png_middlebury(flow):
    """
    Convert flow into middlebury color code image
    :param flow: optical flow map, shape (2, H, W)
    :return: optical flow image in middlebury color, uint8 array (H, W, 3)
    """
    flow = flow.transpose([1, 2, 0])
    u = flow[:, :, 0]
    v = flow[:, :, 1]
    # NOTE(review): u/v are views into the caller's array, so zeroing the
    # unknown-flow pixels below mutates the input in place.
    idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH)
    u[idxUnknow] = 0
    v[idxUnknow] = 0
    # (removed dead maxu/minu/maxv/minv tracking — the values were computed
    # but never used)
    # normalize by the maximum flow magnitude (eps guards against all-zero flow)
    rad = np.sqrt(u ** 2 + v ** 2)
    maxrad = max(-1, np.max(rad))
    u = u / (maxrad + np.finfo(float).eps)
    v = v / (maxrad + np.finfo(float).eps)
    img = compute_color(u, v)
    # blank out the unknown-flow pixels in all three channels
    idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2)
    img[idx] = 0
    return np.uint8(img)
================================================
FILE: utils/interpolation.py
================================================
## Portions of Code from, copyright 2018 Jochen Gast
from __future__ import absolute_import, division, print_function
import torch
from torch import nn
import torch.nn.functional as tf
def _bchw2bhwc(tensor):
return tensor.transpose(1,2).transpose(2,3)
def _bhwc2bchw(tensor):
return tensor.transpose(2,3).transpose(1,2)
class Meshgrid(nn.Module):
    """Lazily-built, cached integer meshgrid.

    Calling the module with (width, height) returns (xx, yy) tensors of shape
    (height, width) holding the x- and y-coordinates of every pixel.  The grid
    is rebuilt only when the requested size changes; device/dtype are applied
    on every call.
    """

    def __init__(self):
        super(Meshgrid, self).__init__()
        self.width = 0
        self.height = 0
        self.xx = None
        self.yy = None

    def _compute_meshgrid(self, width, height):
        # one row of x-coordinates repeated down, one column of
        # y-coordinates repeated across
        xs = torch.arange(0, width)
        ys = torch.arange(0, height)
        self.xx = xs.repeat(height, 1).contiguous()
        self.yy = ys.repeat(width, 1).t().contiguous()

    def forward(self, width, height, device=None, dtype=None):
        if (self.width, self.height) != (width, height):
            self._compute_meshgrid(width=width, height=height)
            self.width = width
            self.height = height
        self.xx = self.xx.to(device=device, dtype=dtype)
        self.yy = self.yy.to(device=device, dtype=dtype)
        return self.xx, self.yy
class BatchSub2Ind(nn.Module):
    """Convert batched (row, col) subscripts into flat per-batch indices.

    For a grid of size (height, width), sample b at subscript (r, c) maps to
    b * height * width + r * width + c, i.e. an index into the flattened
    (B, H, W) volume.
    """

    def __init__(self):
        super(BatchSub2Ind, self).__init__()
        # scratch buffer holding the per-sample batch offsets
        self.register_buffer("_offsets", torch.LongTensor())

    def forward(self, shape, row_sub, col_sub, out=None):
        height, width = shape
        num_samples = row_sub.size(0)
        flat = row_sub * width + col_sub
        # offset every batch element by its slab size (height * width)
        torch.arange(num_samples, out=self._offsets)
        self._offsets *= (height * width)
        offsets = self._offsets.view(-1, 1, 1)
        if out is None:
            return torch.add(flat, offsets)
        torch.add(flat, offsets, out=out)
class Interp2(nn.Module):
    """Bilinear interpolation of a (B, C, H, W) tensor at query coordinates.

    forward(v, xq, yq) samples ``v`` at real-valued pixel coordinates
    ``xq``/``yq`` (assumed shape (B, H', W') — TODO confirm against callers).
    With ``clamp=True`` queries are clamped into the image bounds; otherwise
    out-of-bounds samples are zeroed.  Scratch tensors are registered as
    buffers so they move with the module across devices.
    """
    def __init__(self, clamp=False):
        super(Interp2, self).__init__()
        self._clamp = clamp
        self._batch_sub2ind = BatchSub2Ind()
        # integer coordinates of the four neighboring pixels
        self.register_buffer("_x0", torch.LongTensor())
        self.register_buffer("_x1", torch.LongTensor())
        self.register_buffer("_y0", torch.LongTensor())
        self.register_buffer("_y1", torch.LongTensor())
        # flat indices of the four neighbors
        self.register_buffer("_i00", torch.LongTensor())
        self.register_buffer("_i01", torch.LongTensor())
        self.register_buffer("_i10", torch.LongTensor())
        self.register_buffer("_i11", torch.LongTensor())
        # gathered neighbor values
        self.register_buffer("_v00", torch.FloatTensor())
        self.register_buffer("_v01", torch.FloatTensor())
        self.register_buffer("_v10", torch.FloatTensor())
        self.register_buffer("_v11", torch.FloatTensor())
        # fractional offsets of the query inside its pixel cell
        self.register_buffer("_x", torch.FloatTensor())
        self.register_buffer("_y", torch.FloatTensor())

    def forward(self, v, xq, yq):
        batch_size, channels, height, width = v.size()
        # clamp if wanted
        if self._clamp:
            # NOTE(review): clamp_ mutates the caller's xq/yq in place.
            xq.clamp_(0, width - 1)
            yq.clamp_(0, height - 1)
        # ------------------------------------------------------------------
        # Find neighbors
        #
        # x0 = torch.floor(xq).long(), x0.clamp_(0, width - 1)
        # x1 = x0 + 1, x1.clamp_(0, width - 1)
        # y0 = torch.floor(yq).long(), y0.clamp_(0, height - 1)
        # y1 = y0 + 1, y1.clamp_(0, height - 1)
        #
        # ------------------------------------------------------------------
        self._x0 = torch.floor(xq).long().clamp(0, width - 1)
        self._y0 = torch.floor(yq).long().clamp(0, height - 1)
        self._x1 = torch.add(self._x0, 1).clamp(0, width - 1)
        self._y1 = torch.add(self._y0, 1).clamp(0, height - 1)
        # batch_sub2ind: flat indices of the four neighbors in the B*H*W grid
        self._batch_sub2ind([height, width], self._y0, self._x0, out=self._i00)
        self._batch_sub2ind([height, width], self._y0, self._x1, out=self._i01)
        self._batch_sub2ind([height, width], self._y1, self._x0, out=self._i10)
        self._batch_sub2ind([height, width], self._y1, self._x1, out=self._i11)
        # reshape to (B*H*W, C) so neighbors can be gathered with index_select
        v_flat = _bchw2bhwc(v).contiguous().view(-1, channels)
        torch.index_select(v_flat, dim=0, index=self._i00.view(-1), out=self._v00)
        torch.index_select(v_flat, dim=0, index=self._i01.view(-1), out=self._v01)
        torch.index_select(v_flat, dim=0, index=self._i10.view(-1), out=self._v10)
        torch.index_select(v_flat, dim=0, index=self._i11.view(-1), out=self._v11)
        # local_coords: fractional position of the query inside its cell
        torch.add(xq, - self._x0.float(), out=self._x)
        torch.add(yq, - self._y0.float(), out=self._y)
        # bilinear weights of the four corners
        w00 = torch.unsqueeze((1.0 - self._y) * (1.0 - self._x), dim=1)
        w01 = torch.unsqueeze((1.0 - self._y) * self._x, dim=1)
        w10 = torch.unsqueeze(self._y * (1.0 - self._x), dim=1)
        w11 = torch.unsqueeze(self._y * self._x, dim=1)

        def _reshape(u):
            # (B*H*W, C) back to (B, C, H, W)
            return _bhwc2bchw(u.view(batch_size, height, width, channels))
        # values: weighted sum of the four neighbor values
        values = _reshape(self._v00)*w00 + _reshape(self._v01)*w01 \
            + _reshape(self._v10)*w10 + _reshape(self._v11)*w11

        if self._clamp:
            return values
        else:
            # find_invalid: queries falling outside the image bounds
            invalid = ((xq < 0) | (xq >= width) | (yq < 0) | (yq >= height)).unsqueeze(dim=1).float()
            # maskout invalid
            transformed = invalid * torch.zeros_like(values) + (1.0 - invalid)*values
            return transformed
class Interp2MaskBinary(nn.Module):
    """Mask-aware bilinear interpolation of a (B, C, H, W) tensor.

    Like ``Interp2`` but every neighbor contribution is weighted by a binary
    validity mask and the result is renormalized by the accumulated mask
    weight, so masked-out pixels do not bleed into the output.  Returns
    ``values`` when ``clamp=True``; otherwise ``(values, validity)`` where
    ``validity`` marks output pixels whose valid weight dominates.
    """
    def __init__(self, clamp=False):
        super(Interp2MaskBinary, self).__init__()
        self._clamp = clamp
        self._batch_sub2ind = BatchSub2Ind()
        # integer coordinates of the four neighboring pixels
        self.register_buffer("_x0", torch.LongTensor())
        self.register_buffer("_x1", torch.LongTensor())
        self.register_buffer("_y0", torch.LongTensor())
        self.register_buffer("_y1", torch.LongTensor())
        # flat indices of the four neighbors
        self.register_buffer("_i00", torch.LongTensor())
        self.register_buffer("_i01", torch.LongTensor())
        self.register_buffer("_i10", torch.LongTensor())
        self.register_buffer("_i11", torch.LongTensor())
        # gathered neighbor values
        self.register_buffer("_v00", torch.FloatTensor())
        self.register_buffer("_v01", torch.FloatTensor())
        self.register_buffer("_v10", torch.FloatTensor())
        self.register_buffer("_v11", torch.FloatTensor())
        # gathered neighbor mask values
        self.register_buffer("_m00", torch.FloatTensor())
        self.register_buffer("_m01", torch.FloatTensor())
        self.register_buffer("_m10", torch.FloatTensor())
        self.register_buffer("_m11", torch.FloatTensor())
        # fractional offsets of the query inside its pixel cell
        self.register_buffer("_x", torch.FloatTensor())
        self.register_buffer("_y", torch.FloatTensor())

    def forward(self, v, xq, yq, mask):
        batch_size, channels, height, width = v.size()
        _, channels_mask, _, _ = mask.size()
        if channels_mask != channels:
            # broadcast a single-channel mask across all value channels
            mask = mask.repeat(1, int(channels/channels_mask), 1, 1)
        # clamp if wanted
        if self._clamp:
            # NOTE(review): clamp_ mutates the caller's xq/yq in place.
            xq.clamp_(0, width - 1)
            yq.clamp_(0, height - 1)
        # ------------------------------------------------------------------
        # Find neighbors
        #
        # x0 = torch.floor(xq).long(), x0.clamp_(0, width - 1)
        # x1 = x0 + 1, x1.clamp_(0, width - 1)
        # y0 = torch.floor(yq).long(), y0.clamp_(0, height - 1)
        # y1 = y0 + 1, y1.clamp_(0, height - 1)
        #
        # ------------------------------------------------------------------
        self._x0 = torch.floor(xq).long().clamp(0, width - 1)
        self._y0 = torch.floor(yq).long().clamp(0, height - 1)
        self._x1 = torch.add(self._x0, 1).clamp(0, width - 1)
        self._y1 = torch.add(self._y0, 1).clamp(0, height - 1)
        # batch_sub2ind: flat indices of the four neighbors in the B*H*W grid
        self._batch_sub2ind([height, width], self._y0, self._x0, out=self._i00)
        self._batch_sub2ind([height, width], self._y0, self._x1, out=self._i01)
        self._batch_sub2ind([height, width], self._y1, self._x0, out=self._i10)
        self._batch_sub2ind([height, width], self._y1, self._x1, out=self._i11)
        # reshape values to (B*H*W, C) and gather the four neighbors
        v_flat = _bchw2bhwc(v).contiguous().view(-1, channels)
        torch.index_select(v_flat, dim=0, index=self._i00.view(-1), out=self._v00)
        torch.index_select(v_flat, dim=0, index=self._i01.view(-1), out=self._v01)
        torch.index_select(v_flat, dim=0, index=self._i10.view(-1), out=self._v10)
        torch.index_select(v_flat, dim=0, index=self._i11.view(-1), out=self._v11)
        # reshape the mask the same way and gather its four neighbors
        m_flat = _bchw2bhwc(mask).contiguous().view(-1, channels)
        torch.index_select(m_flat, dim=0, index=self._i00.view(-1), out=self._m00)
        torch.index_select(m_flat, dim=0, index=self._i01.view(-1), out=self._m01)
        torch.index_select(m_flat, dim=0, index=self._i10.view(-1), out=self._m10)
        torch.index_select(m_flat, dim=0, index=self._i11.view(-1), out=self._m11)
        # local_coords: fractional position of the query inside its cell
        torch.add(xq, - self._x0.float(), out=self._x)
        torch.add(yq, - self._y0.float(), out=self._y)
        # bilinear weights of the four corners
        w00 = torch.unsqueeze((1.0 - self._y) * (1.0 - self._x), dim=1)
        w01 = torch.unsqueeze((1.0 - self._y) * self._x, dim=1)
        w10 = torch.unsqueeze(self._y * (1.0 - self._x), dim=1)
        w11 = torch.unsqueeze(self._y * self._x, dim=1)

        def _reshape(u):
            # (B*H*W, C) back to (B, C, H, W)
            return _bhwc2bchw(u.view(batch_size, height, width, channels))
        # values: mask-weighted sum of the four neighbor values
        values = _reshape(self._m00) * _reshape(self._v00) * w00 + _reshape(self._m01) * _reshape(
            self._v01) * w01 + _reshape(self._m10) * _reshape(self._v10) * w10 + _reshape(self._m11) * _reshape(
            self._v11) * w11
        # total bilinear weight contributed by valid (mask == 1) neighbors
        m_weights = _reshape(self._m00) * w00 + _reshape(self._m01) * w01 + _reshape(self._m10) * w10 + _reshape(
            self._m11) * w11
        # renormalize so only valid neighbors contribute (eps avoids 0/0)
        values = values / (m_weights + 1e-12)
        # a pixel is invalid when the masked-out weight dominates the valid one
        invalid_mask = (((1 - m_weights) / (m_weights + 1e-12)) > 0.5)[:, 0:1, :, :]

        if self._clamp:
            return values
        else:
            # find_invalid: out-of-bounds queries or mask-dominated pixels
            invalid = ((xq < 0) | (xq >= width) | (yq < 0) | (yq >= height) | invalid_mask.squeeze(dim=1)).unsqueeze(dim=1).float()
            # maskout invalid
            transformed = invalid * torch.zeros_like(values) + (1.0 - invalid) * values
            return transformed, (1 - invalid_mask).float()
def resize2D(inputs, size_targets, mode="bilinear"):
    """Resize a (B, C, H, W) tensor to the spatial size ``size_targets``.

    Uses adaptive average pooling when both target dimensions are at most the
    input dimensions (downscaling) and interpolation otherwise (upscaling or
    mixed per-dimension resizing).  Returns the input unchanged when the size
    already matches.

    Args:
        inputs: 4-D tensor (B, C, H, W).
        size_targets: target (height, width) as a list or tuple of two ints.
        mode: interpolation mode used for the non-pooling path.
    """
    size_targets = list(size_targets)
    size_inputs = [inputs.size(2), inputs.size(3)]

    if size_inputs == size_targets:
        return inputs  # nothing to do
    # Bug fix: the original tested `any([size_targets < size_inputs])`, a
    # lexicographic comparison of whole lists — compare each dimension instead.
    if all(t <= s for t, s in zip(size_targets, size_inputs)):
        return tf.adaptive_avg_pool2d(inputs, size_targets)  # downscaling
    # tf.upsample is deprecated; tf.interpolate is the supported equivalent
    return tf.interpolate(inputs, size=size_targets, mode=mode)  # upsampling
def resize2D_as(inputs, output_as, mode="bilinear"):
    """Resize ``inputs`` spatially to match the (H, W) of ``output_as``."""
    target = [output_as.size(2), output_as.size(3)]
    return resize2D(inputs, target, mode=mode)