Repository: visinf/irr
Branch: master
Commit: dacd07b1dc96
Files: 89
Total size: 240.7 MB
Directory structure:
gitextract_pwdsdpvy/
├── .gitignore
├── LICENSE
├── README.md
├── __init__.py
├── augmentations.py
├── commandline.py
├── configuration.py
├── datasets/
│ ├── __init__.py
│ ├── common.py
│ ├── flyingThings3D.py
│ ├── flyingchairs.py
│ ├── flyingchairsOcc.py
│ ├── kitti_combined.py
│ ├── sintel.py
│ └── transforms.py
├── flyingchairsocc/
│ └── README.md
├── install.sh
├── logger.py
├── losses.py
├── main.py
├── models/
│ ├── IRR_FlowNet.py
│ ├── IRR_PWC.py
│ ├── __init__.py
│ ├── correlation_package/
│ │ ├── __init__.py
│ │ ├── correlation.py
│ │ ├── correlation_cuda.cc
│ │ ├── correlation_cuda_kernel.cu
│ │ ├── correlation_cuda_kernel.cuh
│ │ └── setup.py
│ ├── correlation_package_cu9/
│ │ ├── __init__.py
│ │ ├── correlation.py
│ │ ├── correlation_cuda.cc
│ │ ├── correlation_cuda_kernel.cu
│ │ ├── correlation_cuda_kernel.cuh
│ │ └── setup.py
│ ├── flownet1s.py
│ ├── flownet1s_irr.py
│ ├── flownet1s_irr_bi.py
│ ├── flownet1s_irr_occ.py
│ ├── flownet1s_irr_occ_bi.py
│ ├── flownet_modules.py
│ ├── irr_modules.py
│ ├── pwc_modules.py
│ ├── pwcnet.py
│ ├── pwcnet_bi.py
│ ├── pwcnet_irr.py
│ ├── pwcnet_irr_bi.py
│ ├── pwcnet_irr_occ.py
│ ├── pwcnet_irr_occ_bi.py
│ ├── pwcnet_occ.py
│ └── pwcnet_occ_bi.py
├── optim/
│ └── __init__.py
├── runtime.py
├── saved_check_point/
│ └── pwcnet/
│ ├── IRR-PWC_flyingchairsOcc/
│ │ ├── checkpoint_best.ckpt
│ │ └── checkpoint_latest.ckpt
│ ├── IRR-PWC_kitti/
│ │ ├── checkpoint_best.ckpt
│ │ └── checkpoint_latest.ckpt
│ ├── IRR-PWC_sintel/
│ │ ├── checkpoint_best.ckpt
│ │ └── checkpoint_latest.ckpt
│ ├── IRR-PWC_things3d/
│ │ ├── checkpoint_best.ckpt
│ │ └── checkpoint_latest.ckpt
│ ├── PWCNet/
│ │ └── checkpoint_best.ckpt
│ └── PWCNet-irr/
│ └── checkpoint_best.ckpt
├── scripts/
│ ├── IRR-FlowNet_flyingChairsOcc.sh
│ ├── IRR-PWC_flyingChairsOcc.sh
│ ├── IRR-PWC_kitti_train.sh
│ ├── IRR-PWC_kitti_train_full.sh
│ ├── IRR-PWC_sintel_train.sh
│ ├── IRR-PWC_sintel_train_full.sh
│ ├── IRR-PWC_things3d.sh
│ ├── flownet1s.sh
│ ├── flownet1s_irr1.sh
│ ├── flownet1s_irr2.sh
│ ├── pwcnet.sh
│ ├── pwcnet_irr.sh
│ └── validation/
│ ├── IRR-FlowNet_flyingChairs.sh
│ ├── IRR-PWC_flyingChairs.sh
│ ├── IRR-PWC_kitti.sh
│ ├── IRR-PWC_sintel.sh
│ ├── IRR-PWC_things3d.sh
│ ├── flownet1s.sh
│ ├── flownet1s_irr1.sh
│ ├── flownet1s_irr2.sh
│ ├── pwcnet.sh
│ └── pwcnet_irr.sh
├── tools.py
└── utils/
├── __init__.py
├── flow.py
└── interpolation.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
*.pyc
*.so
*.o
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
# Iterative Residual Refinement <br/> for Joint Optical Flow and Occlusion Estimation
<img src="output.gif">
This repository is the PyTorch implementation of the paper:
**Iterative Residual Refinement for Joint Optical Flow and Occlusion Estimation (CVPR 2019)**
[Junhwa Hur](https://sites.google.com/site/hurjunhwa) and [Stefan Roth](https://www.visinf.tu-darmstadt.de/team_members/sroth/sroth.en.jsp)
Department of Computer Science, TU Darmstadt
[[Preprint]](https://arxiv.org/pdf/1904.05290.pdf)   [[Proceeding]](http://openaccess.thecvf.com/content_CVPR_2019/papers/Hur_Iterative_Residual_Refinement_for_Joint_Optical_Flow_and_Occlusion_Estimation_CVPR_2019_paper.pdf)   [[Supplemental]](http://openaccess.thecvf.com/content_CVPR_2019/supplemental/Hur_Iterative_Residual_Refinement_CVPR_2019_supplemental.pdf)
Please cite the paper below if you find our paper and source code useful.
@inproceedings{Hur:2019:IRR,
  Author = {Junhwa Hur and Stefan Roth},
  Booktitle = {CVPR},
  Title = {Iterative Residual Refinement for Joint Optical Flow and Occlusion Estimation},
  Year = {2019}
}
Contact: junhwa.hur[at]visinf.tu-darmstadt.de
## Getting started
This code was originally developed under Anaconda (Python 3.6), PyTorch 0.4.1 and CUDA 8.0 on Ubuntu 16.04.
1. Please install the following:
- Anaconda
- PyTorch (now compatible with __PyTorch 1.5.0__)
- tqdm (`conda install -c conda-forge tqdm==4.40.0`)
- (any missing packages that the code requires)
2. The datasets used for this project are the following:
- [FlyingChairsOcc dataset](https://github.com/visinf/irr/tree/master/flyingchairsocc)
- [FlyingThings3D subset](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html)
- [MPI Sintel Dataset](http://sintel.is.tue.mpg.de/downloads) + [revised occlusion GT](https://download.visinf.tu-darmstadt.de/data/flyingchairs_occ/occlusions_rev.zip)
- [KITTI Optical Flow 2015](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow) and [KITTI Optical Flow 2012](http://www.cvlibs.net/datasets/kitti/eval_stereo_flow.php?benchmark=flow)
## Training
The `scripts` folder contains the training scripts for the experiments demonstrated in the paper.
To train a model, simply run the corresponding script file, e.g., `./IRR-PWC_flyingChairsOcc.sh`.
In the script files, please configure your own experiment directory (`EXPERIMENTS_HOME`) and the dataset directories on your local system (e.g., `SINTEL_HOME` or `KITTI_HOME`).
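As a sketch, configuring a script might look as follows. The variable names `EXPERIMENTS_HOME`, `SINTEL_HOME`, and `KITTI_HOME` come from the README above; the paths are placeholders for your local system, not real locations.

```shell
# Hypothetical setup before launching a training script; adjust the
# placeholder paths to your own machine.
export EXPERIMENTS_HOME=/path/to/experiments
export SINTEL_HOME=/path/to/MPI-Sintel
export KITTI_HOME=/path/to/KITTI

# Then run the desired script from the scripts/ folder, e.g.:
# cd scripts && ./IRR-PWC_flyingChairsOcc.sh
```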
## Pretrained Models
The `saved_check_point` folder contains the pretrained models of *i)* the baseline, *ii)* baseline + irr, and *iii)* the full models.
Additional pretrained models from the ablation study (Table 1 in the main paper) and their training scripts are available upon request.
## Inference
The scripts for testing the pre-trained models are located in `scripts/validation`.
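For example, evaluating a pretrained model might be set up as follows. The checkpoint path matches a file shipped in this repository; the dataset path is a placeholder, and the exact variable names expected inside the validation scripts may differ.

```shell
# Hypothetical setup for running a validation script against a pretrained
# checkpoint (dataset path is a placeholder).
export CHECKPOINT=saved_check_point/pwcnet/IRR-PWC_sintel/checkpoint_best.ckpt
export SINTEL_HOME=/path/to/MPI-Sintel

# Then run the matching validation script, e.g.:
# cd scripts/validation && ./IRR-PWC_sintel.sh
```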
## Acknowledgement
Portions of the source code (e.g., training pipeline, runtime, argument parser, and logger) are from [Jochen Gast](https://scholar.google.com/citations?user=tmRcFacAAAAJ&hl=en).
================================================
FILE: __init__.py
================================================
================================================
FILE: augmentations.py
================================================
## Portions of code adapted from code copyright 2018 Jochen Gast
from __future__ import absolute_import, division, print_function
import torch
import torch.nn as nn
from utils.interpolation import Interp2, Interp2MaskBinary
from utils.interpolation import Meshgrid
import numpy as np
def denormalize_coords(xx, yy, width, height):
""" scale indices from [-1, 1] to [0, width/height] """
xx = 0.5 * (width - 1.0) * (xx.float() + 1.0)
yy = 0.5 * (height - 1.0) * (yy.float() + 1.0)
return xx, yy
def normalize_coords(xx, yy, width, height):
""" scale indices from [0, width/height] to [-1, 1] """
xx = (2.0 / (width - 1.0)) * xx.float() - 1.0
yy = (2.0 / (height - 1.0)) * yy.float() - 1.0
return xx, yy
def apply_transform_to_params(theta0, theta_transform):
a1 = theta0[:, 0]
a2 = theta0[:, 1]
a3 = theta0[:, 2]
a4 = theta0[:, 3]
a5 = theta0[:, 4]
a6 = theta0[:, 5]
#
b1 = theta_transform[:, 0]
b2 = theta_transform[:, 1]
b3 = theta_transform[:, 2]
b4 = theta_transform[:, 3]
b5 = theta_transform[:, 4]
b6 = theta_transform[:, 5]
#
c1 = a1 * b1 + a4 * b2
c2 = a2 * b1 + a5 * b2
c3 = b3 + a3 * b1 + a6 * b2
c4 = a1 * b4 + a4 * b5
c5 = a2 * b4 + a5 * b5
c6 = b6 + a3 * b4 + a6 * b5
#
new_theta = torch.stack([c1, c2, c3, c4, c5, c6], dim=1)
return new_theta
class _IdentityParams(nn.Module):
def __init__(self):
super(_IdentityParams, self).__init__()
self._batch_size = 0
self.register_buffer("_o", torch.FloatTensor())
self.register_buffer("_i", torch.FloatTensor())
def _update(self, batch_size):
torch.zeros([batch_size, 1], out=self._o)
torch.ones([batch_size, 1], out=self._i)
return torch.cat([self._i, self._o, self._o, self._o, self._i, self._o], dim=1)
def forward(self, batch_size):
if self._batch_size != batch_size:
self._identity_params = self._update(batch_size)
self._batch_size = batch_size
return self._identity_params
class RandomMirror(nn.Module):
def __init__(self, vertical=True, p=0.5):
super(RandomMirror, self).__init__()
self._batch_size = 0
self._p = p
self._vertical = vertical
self.register_buffer("_mirror_probs", torch.FloatTensor())
def update_probs(self, batch_size):
torch.ones([batch_size, 1], out=self._mirror_probs)
self._mirror_probs *= self._p
def forward(self, theta1, theta2):
batch_size = theta1.size(0)
if batch_size != self._batch_size:
self.update_probs(batch_size)
self._batch_size = batch_size
# apply random sign to a1 a2 a3 (these are the guys responsible for x)
sign = torch.sign(2.0 * torch.bernoulli(self._mirror_probs) - 1.0)
i = torch.ones_like(sign)
horizontal_mirror = torch.cat([sign, sign, sign, i, i, i], dim=1)
theta1 *= horizontal_mirror
theta2 *= horizontal_mirror
# apply random sign to a4 a5 a6 (these are the guys responsible for y)
if self._vertical:
sign = torch.sign(2.0 * torch.bernoulli(self._mirror_probs) - 1.0)
vertical_mirror = torch.cat([i, i, i, sign, sign, sign], dim=1)
theta1 *= vertical_mirror
theta2 *= vertical_mirror
return theta1, theta2
class RandomCrop(nn.Module):
"""Crops the given image tensors at a random location to have a region of
the given size. size can be a tuple (target_height, target_width)
or an integer, in which case the target will be of a square shape (size, size)
"""
def __init__(self, crop):
super(RandomCrop, self).__init__()
self._crop_size = crop
self.register_buffer("_x", torch.LongTensor())
self.register_buffer("_y", torch.LongTensor())
def forward(self, im1, im2, flo):
batch_size, _, height, width = im1.size()
crop_height, crop_width = self._crop_size
# check whether there is anything to do
if crop_height < 1 or crop_width < 1:
return im1, im2, flo
# get starting positions (upper bound includes the last valid offset)
self._x.random_(0, width - crop_width + 1)
self._y.random_(0, height - crop_height + 1)
str_x = int(self._x)
str_y = int(self._y)
end_x = str_x + crop_width
end_y = str_y + crop_height
im1 = im1[:, :, str_y:end_y, str_x:end_x]
im2 = im2[:, :, str_y:end_y, str_x:end_x]
flo = flo[:, :, str_y:end_y, str_x:end_x]
return im1, im2, flo
class RandomAffineFlow(nn.Module):
def __init__(self, args, addnoise=True):
super(RandomAffineFlow, self).__init__()
self._args = args
self._interp2 = Interp2(clamp=False)
self._flow_interp2 = Interp2(clamp=False)
self._meshgrid = Meshgrid()
self._identity = _IdentityParams()
self._random_mirror = RandomMirror()
self._addnoise = addnoise
self.register_buffer("_noise1", torch.FloatTensor())
self.register_buffer("_noise2", torch.FloatTensor())
self.register_buffer("_xbounds", torch.FloatTensor([-1, -1, 1, 1]))
self.register_buffer("_ybounds", torch.FloatTensor([-1, 1, -1, 1]))
def inverse_transform_coords(self, width, height, thetas, offset_x=None, offset_y=None):
xx, yy = self._meshgrid(width=width, height=height, device=thetas.device, dtype=thetas.dtype)
xx = torch.unsqueeze(xx, dim=0).float()
yy = torch.unsqueeze(yy, dim=0).float()
if offset_x is not None:
xx = xx + offset_x
if offset_y is not None:
yy = yy + offset_y
a1 = thetas[:, 0].contiguous().view(-1, 1, 1)
a2 = thetas[:, 1].contiguous().view(-1, 1, 1)
a3 = thetas[:, 2].contiguous().view(-1, 1, 1)
a4 = thetas[:, 3].contiguous().view(-1, 1, 1)
a5 = thetas[:, 4].contiguous().view(-1, 1, 1)
a6 = thetas[:, 5].contiguous().view(-1, 1, 1)
xx, yy = normalize_coords(xx, yy, width=width, height=height)
xq = a1 * xx + a2 * yy + a3
yq = a4 * xx + a5 * yy + a6
xq, yq = denormalize_coords(xq, yq, width=width, height=height)
return xq, yq
def transform_coords(self, width, height, thetas):
xx1, yy1 = self._meshgrid(width=width, height=height, device=thetas.device, dtype=thetas.dtype)
xx, yy = normalize_coords(xx1, yy1, width=width, height=height)
def _unsqueeze12(u):
return torch.unsqueeze(torch.unsqueeze(u, dim=1), dim=1)
a1 = _unsqueeze12(thetas[:, 0])
a2 = _unsqueeze12(thetas[:, 1])
a3 = _unsqueeze12(thetas[:, 2])
a4 = _unsqueeze12(thetas[:, 3])
a5 = _unsqueeze12(thetas[:, 4])
a6 = _unsqueeze12(thetas[:, 5])
#
z = a1 * a5 - a2 * a4
b1 = a5 / z
b2 = - a2 / z
b4 = - a4 / z
b5 = a1 / z
#
xhat = xx - a3
yhat = yy - a6
xq = b1 * xhat + b2 * yhat
yq = b4 * xhat + b5 * yhat
xq, yq = denormalize_coords(xq, yq, width=width, height=height)
return xq, yq
def find_invalid(self, width, height, thetas):
x = self._xbounds
y = self._ybounds
#
a1 = torch.unsqueeze(thetas[:, 0], dim=1)
a2 = torch.unsqueeze(thetas[:, 1], dim=1)
a3 = torch.unsqueeze(thetas[:, 2], dim=1)
a4 = torch.unsqueeze(thetas[:, 3], dim=1)
a5 = torch.unsqueeze(thetas[:, 4], dim=1)
a6 = torch.unsqueeze(thetas[:, 5], dim=1)
#
z = a1 * a5 - a2 * a4
b1 = a5 / z
b2 = - a2 / z
b4 = - a4 / z
b5 = a1 / z
#
xhat = x - a3
yhat = y - a6
xq = b1 * xhat + b2 * yhat
yq = b4 * xhat + b5 * yhat
xq, yq = denormalize_coords(xq, yq, width=width, height=height)
#
invalid = (
(xq < 0) | (yq < 0) | (xq >= width) | (yq >= height)
).sum(dim=1, keepdim=True) > 0
return invalid
def apply_random_transforms_to_params(self,
theta0,
max_translate,
min_zoom, max_zoom,
min_squeeze, max_squeeze,
min_rotate, max_rotate,
validate_size=None):
max_translate *= 0.5
batch_size = theta0.size(0)
height, width = validate_size
# collect valid params here
thetas = torch.zeros_like(theta0)
zoom = theta0.new(batch_size, 1).zero_()
squeeze = torch.zeros_like(zoom)
tx = torch.zeros_like(zoom)
ty = torch.zeros_like(zoom)
phi = torch.zeros_like(zoom)
invalid = torch.ones_like(zoom).byte()
while invalid.sum() > 0:
# random sampling
zoom.uniform_(min_zoom, max_zoom)
squeeze.uniform_(min_squeeze, max_squeeze)
tx.uniform_(-max_translate, max_translate)
ty.uniform_(-max_translate, max_translate)
phi.uniform_(min_rotate, max_rotate)
# construct affine parameters
sx = zoom * squeeze
sy = zoom / squeeze
sin_phi = torch.sin(phi)
cos_phi = torch.cos(phi)
b1 = cos_phi * sx
b2 = sin_phi * sy
b3 = tx
b4 = - sin_phi * sx
b5 = cos_phi * sy
b6 = ty
theta_transform = torch.cat([b1, b2, b3, b4, b5, b6], dim=1)
theta_try = apply_transform_to_params(theta0, theta_transform)
thetas = invalid.float() * theta_try + (1 - invalid.float()) * thetas
# compute new invalid ones
invalid = self.find_invalid(width=width, height=height, thetas=thetas)
# here we should have good thetas within borders
return thetas
def transform_image(self, images, thetas):
batch_size, channels, height, width = images.size()
xq, yq = self.transform_coords(width=width, height=height, thetas=thetas)
transformed = self._interp2(images, xq, yq)
return transformed
def transform_flow(self, flow, theta1, theta2):
batch_size, channels, height, width = flow.size()
u = flow[:, 0, :, :]
v = flow[:, 1, :, :]
# inverse transform coords
x0, y0 = self.inverse_transform_coords(
width=width, height=height, thetas=theta1)
x1, y1 = self.inverse_transform_coords(
width=width, height=height, thetas=theta2, offset_x=u, offset_y=v)
# subtract and create new flow
u = x1 - x0
v = y1 - y0
new_flow = torch.stack([u, v], dim=1)
# transform coords
xq, yq = self.transform_coords(width=width, height=height, thetas=theta1)
# interp2
transformed = self._flow_interp2(new_flow, xq, yq)
return transformed
def forward(self, example_dict):
im1 = example_dict["input1"]
im2 = example_dict["input2"]
flo = example_dict["target1"]
batch_size = im1.size(0)
height = im1.size(2)
width = im1.size(3)
# identity = no transform
theta0 = self._identity(batch_size)
# # global transform
theta1 = self.apply_random_transforms_to_params(
theta0,
max_translate=0.2,
min_zoom=1.0, max_zoom=1.5,
min_squeeze=0.86, max_squeeze=1.16,
min_rotate=-0.2, max_rotate=0.2,
validate_size=[height, width])
# relative transform
theta2 = self.apply_random_transforms_to_params(
theta1,
max_translate=0.015,
min_zoom=0.985, max_zoom=1.015,
min_squeeze=1.0, max_squeeze=1.0,
min_rotate=-0.015, max_rotate=0.015,
validate_size=[height, width])
# random flip images
theta1, theta2 = self._random_mirror(theta1, theta2)
im1 = self.transform_image(im1, theta1)
im2 = self.transform_image(im2, theta2)
flo = self.transform_flow(flo, theta1, theta2)
if self._addnoise:
stddev = np.random.uniform(0.0, 0.04)
self._noise1.resize_as_(im1)
self._noise2.resize_as_(im2)
self._noise1.normal_(std=stddev)
self._noise2.normal_(std=stddev)
im1 += self._noise1
im2 += self._noise2
im1.clamp_(0.0, 1.0)
im2.clamp_(0.0, 1.0)
# construct updated dictionaries
example_dict["input1"] = im1
example_dict["input2"] = im2
example_dict["target1"] = flo
return example_dict
class RandomAffineFlowOcc(nn.Module):
def __init__(self, args, addnoise=True, crop=None):
super(RandomAffineFlowOcc, self).__init__()
self._args = args
self._interp2 = Interp2(clamp=False)
self._flow_interp2 = Interp2(clamp=False)
self._meshgrid = Meshgrid()
self._identity = _IdentityParams()
self._random_mirror = RandomMirror()
self._addnoise = addnoise
self._crop = crop
self.register_buffer("_noise1", torch.FloatTensor())
self.register_buffer("_noise2", torch.FloatTensor())
self.register_buffer("_xbounds", torch.FloatTensor([-1, -1, 1, 1]))
self.register_buffer("_ybounds", torch.FloatTensor([-1, 1, -1, 1]))
self.register_buffer("_x", torch.IntTensor(1))
self.register_buffer("_y", torch.IntTensor(1))
def inverse_transform_coords(self, width, height, thetas, offset_x=None, offset_y=None):
xx, yy = self._meshgrid(width=width, height=height, device=thetas.device, dtype=thetas.dtype)
xx = torch.unsqueeze(xx, dim=0).float()
yy = torch.unsqueeze(yy, dim=0).float()
if offset_x is not None:
xx = xx + offset_x
if offset_y is not None:
yy = yy + offset_y
a1 = thetas[:, 0].contiguous().view(-1, 1, 1)
a2 = thetas[:, 1].contiguous().view(-1, 1, 1)
a3 = thetas[:, 2].contiguous().view(-1, 1, 1)
a4 = thetas[:, 3].contiguous().view(-1, 1, 1)
a5 = thetas[:, 4].contiguous().view(-1, 1, 1)
a6 = thetas[:, 5].contiguous().view(-1, 1, 1)
xx, yy = normalize_coords(xx, yy, width=width, height=height)
xq = a1 * xx + a2 * yy + a3
yq = a4 * xx + a5 * yy + a6
xq, yq = denormalize_coords(xq, yq, width=width, height=height)
return xq, yq
def transform_coords(self, width, height, thetas):
xx1, yy1 = self._meshgrid(width=width, height=height, device=thetas.device, dtype=thetas.dtype)
xx, yy = normalize_coords(xx1, yy1, width=width, height=height)
def _unsqueeze12(u):
return torch.unsqueeze(torch.unsqueeze(u, dim=1), dim=1)
a1 = _unsqueeze12(thetas[:, 0])
a2 = _unsqueeze12(thetas[:, 1])
a3 = _unsqueeze12(thetas[:, 2])
a4 = _unsqueeze12(thetas[:, 3])
a5 = _unsqueeze12(thetas[:, 4])
a6 = _unsqueeze12(thetas[:, 5])
#
z = a1 * a5 - a2 * a4
b1 = a5 / z
b2 = - a2 / z
b4 = - a4 / z
b5 = a1 / z
#
xhat = xx - a3
yhat = yy - a6
xq = b1 * xhat + b2 * yhat
yq = b4 * xhat + b5 * yhat
xq, yq = denormalize_coords(xq, yq, width=width, height=height)
return xq, yq
def find_invalid(self, width, height, thetas):
x = self._xbounds
y = self._ybounds
#
a1 = torch.unsqueeze(thetas[:, 0], dim=1)
a2 = torch.unsqueeze(thetas[:, 1], dim=1)
a3 = torch.unsqueeze(thetas[:, 2], dim=1)
a4 = torch.unsqueeze(thetas[:, 3], dim=1)
a5 = torch.unsqueeze(thetas[:, 4], dim=1)
a6 = torch.unsqueeze(thetas[:, 5], dim=1)
#
z = a1 * a5 - a2 * a4
b1 = a5 / z
b2 = - a2 / z
b4 = - a4 / z
b5 = a1 / z
#
xhat = x - a3
yhat = y - a6
xq = b1 * xhat + b2 * yhat
yq = b4 * xhat + b5 * yhat
xq, yq = denormalize_coords(xq, yq, width=width, height=height)
#
invalid = (
(xq < 0) | (yq < 0) | (xq >= width) | (yq >= height)
).sum(dim=1, keepdim=True) > 0
return invalid
def apply_random_transforms_to_params(self,
theta0,
max_translate,
min_zoom, max_zoom,
min_squeeze, max_squeeze,
min_rotate, max_rotate,
validate_size=None):
max_translate *= 0.5
batch_size = theta0.size(0)
height, width = validate_size
# collect valid params here
thetas = torch.zeros_like(theta0)
zoom = theta0.new(batch_size, 1).zero_()
squeeze = torch.zeros_like(zoom)
tx = torch.zeros_like(zoom)
ty = torch.zeros_like(zoom)
phi = torch.zeros_like(zoom)
invalid = torch.ones_like(zoom).byte()
while invalid.sum() > 0:
# random sampling
zoom.uniform_(min_zoom, max_zoom)
squeeze.uniform_(min_squeeze, max_squeeze)
tx.uniform_(-max_translate, max_translate)
ty.uniform_(-max_translate, max_translate)
phi.uniform_(min_rotate, max_rotate)
# construct affine parameters
sx = zoom * squeeze
sy = zoom / squeeze
sin_phi = torch.sin(phi)
cos_phi = torch.cos(phi)
b1 = cos_phi * sx
b2 = sin_phi * sy
b3 = tx
b4 = - sin_phi * sx
b5 = cos_phi * sy
b6 = ty
theta_transform = torch.cat([b1, b2, b3, b4, b5, b6], dim=1)
theta_try = apply_transform_to_params(theta0, theta_transform)
thetas = invalid.float() * theta_try + (1. - invalid.float()) * thetas
# compute new invalid ones
invalid = self.find_invalid(width=width, height=height, thetas=thetas)
# here we should have good thetas within borders
return thetas
def transform_image(self, images, thetas):
batch_size, channels, height, width = images.size()
xq, yq = self.transform_coords(width=width, height=height, thetas=thetas)
transformed = self._interp2(images, xq, yq)
return transformed
def transform_flow(self, flow, theta1, theta2):
batch_size, channels, height, width = flow.size()
u = flow[:, 0, :, :]
v = flow[:, 1, :, :]
# inverse transform coords
x0, y0 = self.inverse_transform_coords(
width=width, height=height, thetas=theta1)
x1, y1 = self.inverse_transform_coords(
width=width, height=height, thetas=theta2, offset_x=u, offset_y=v)
# subtract and create new flow
u = x1 - x0
v = y1 - y0
new_flow = torch.stack([u, v], dim=1)
# transform coords
xq, yq = self.transform_coords(width=width, height=height, thetas=theta1)
# interp2
transformed = self._flow_interp2(new_flow, xq, yq)
return transformed
def check_out_of_bound(self, flow, occ, batch_size):
_, _, height, width = flow.size()
u = flow[:, 0, :, :]
v = flow[:, 1, :, :]
xx, yy = self._meshgrid(width=width, height=height, device=flow.device, dtype=flow.dtype)
xx = torch.unsqueeze(xx, dim=0).float()
yy = torch.unsqueeze(yy, dim=0).float()
xx = xx.expand(batch_size, -1, -1) + u
yy = yy.expand(batch_size, -1, -1) + v
out_of_bound = ((xx < 0) | (yy < 0) | (xx >= width) | (yy >= height)).float().unsqueeze(1)
occ = torch.clamp(out_of_bound + occ, 0, 1)
return occ
def random_crop(self, im1, im2, flo_f, flo_b, occ1, occ2):
_, _, height, width = im1.size()
crop_height, crop_width = self._crop
# get starting positions
self._x.random_(0, width - crop_width + 1)
self._y.random_(0, height - crop_height + 1)
str_x = int(self._x)
str_y = int(self._y)
end_x = int(self._x + crop_width)
end_y = int(self._y + crop_height)
im1 = im1[:, :, str_y:end_y, str_x:end_x]
im2 = im2[:, :, str_y:end_y, str_x:end_x]
flo_f = flo_f[:, :, str_y:end_y, str_x:end_x]
flo_b = flo_b[:, :, str_y:end_y, str_x:end_x]
occ1 = occ1[:, :, str_y:end_y, str_x:end_x]
occ2 = occ2[:, :, str_y:end_y, str_x:end_x]
return im1, im2, flo_f, flo_b, occ1, occ2
def forward(self, example_dict):
im1 = example_dict["input1"]
im2 = example_dict["input2"]
flo_f = example_dict["target1"]
flo_b = example_dict["target2"]
occ1 = example_dict["target_occ1"]
occ2 = example_dict["target_occ2"]
batch_size = im1.size(0)
height = im1.size(2)
width = im1.size(3)
# identity = no transform
theta0 = self._identity(batch_size)
# global transform
theta1 = self.apply_random_transforms_to_params(
theta0,
max_translate=0.2,
min_zoom=1.0, max_zoom=1.5,
min_squeeze=0.86, max_squeeze=1.16,
min_rotate=-0.2, max_rotate=0.2,
validate_size=[height, width])
# relative transform
theta2 = self.apply_random_transforms_to_params(
theta1,
max_translate=0.015,
min_zoom=0.985, max_zoom=1.015,
min_squeeze=1.0, max_squeeze=1.0,
min_rotate=-0.015, max_rotate=0.015,
validate_size=[height, width])
# random flip images
theta1, theta2 = self._random_mirror(theta1, theta2)
im1 = self.transform_image(im1, theta1)
im2 = self.transform_image(im2, theta2)
flo_f = self.transform_flow(flo_f, theta1, theta2)
flo_b = self.transform_flow(flo_b, theta2, theta1)
occ1 = self.transform_image(occ1, theta1)
occ2 = self.transform_image(occ2, theta2)
if self._addnoise:
stddev = np.random.uniform(0.0, 0.04)
self._noise1.resize_as_(im1)
self._noise2.resize_as_(im2)
self._noise1.normal_(std=stddev)
self._noise2.normal_(std=stddev)
im1 += self._noise1
im2 += self._noise2
im1.clamp_(0.0, 1.0)
im2.clamp_(0.0, 1.0)
if self._crop is not None:
im1, im2, flo_f, flo_b, occ1, occ2 = self.random_crop(im1, im2, flo_f, flo_b, occ1, occ2)
occ1 = self.check_out_of_bound(flo_f, occ1, batch_size)
occ2 = self.check_out_of_bound(flo_b, occ2, batch_size)
example_dict["input1"] = im1
example_dict["input2"] = im2
example_dict["target1"] = flo_f
example_dict["target2"] = flo_b
example_dict["target_occ1"] = occ1
example_dict["target_occ2"] = occ2
return example_dict
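The `while invalid.sum() > 0` loop in `apply_random_transforms_to_params` above is a batched rejection sampler: it re-draws parameters only for the batch entries whose transform would sample outside the image borders, keeping already-accepted entries untouched. A minimal NumPy sketch of that idea (illustrative only -- `sample_valid_translations`, the 0.15 acceptance threshold, and the bounded retry count are stand-ins, not the repo's torch code, which loops until `find_invalid` accepts everything):

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_valid_translations(batch_size, max_translate, max_tries=100):
    # thetas holds the accepted parameter per batch entry
    thetas = np.zeros(batch_size)
    invalid = np.ones(batch_size, dtype=bool)
    for _ in range(max_tries):
        if not invalid.any():
            break
        draw = rng.uniform(-max_translate, max_translate, size=batch_size)
        # only overwrite the entries that are still invalid,
        # mirroring: invalid * theta_try + (1 - invalid) * thetas
        thetas = np.where(invalid, draw, thetas)
        # toy validity check standing in for find_invalid()
        invalid = np.abs(thetas) >= 0.15
    return thetas

params = sample_valid_translations(8, max_translate=0.2)
```

The masked blend is what lets the whole batch be resampled in one vectorized draw instead of looping per example.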
class RandomAffineFlowOccSintel(nn.Module):
def __init__(self, args, addnoise=True, crop=None):
super(RandomAffineFlowOccSintel, self).__init__()
self._args = args
self._interp2 = Interp2(clamp=False)
self._flow_interp2 = Interp2(clamp=False)
self._meshgrid = Meshgrid()
self._identity = _IdentityParams()
self._random_mirror = RandomMirror()
self._addnoise = addnoise
self._crop = crop
self.register_buffer("_noise1", torch.FloatTensor())
self.register_buffer("_noise2", torch.FloatTensor())
self.register_buffer("_xbounds", torch.FloatTensor([-1, -1, 1, 1]))
self.register_buffer("_ybounds", torch.FloatTensor([-1, 1, -1, 1]))
self.register_buffer("_x", torch.IntTensor(1))
self.register_buffer("_y", torch.IntTensor(1))
def inverse_transform_coords(self, width, height, thetas, offset_x=None, offset_y=None):
xx, yy = self._meshgrid(width=width, height=height, device=thetas.device, dtype=thetas.dtype)
xx = torch.unsqueeze(xx, dim=0).float()
yy = torch.unsqueeze(yy, dim=0).float()
if offset_x is not None:
xx = xx + offset_x
if offset_y is not None:
yy = yy + offset_y
a1 = thetas[:, 0].contiguous().view(-1, 1, 1)
a2 = thetas[:, 1].contiguous().view(-1, 1, 1)
a3 = thetas[:, 2].contiguous().view(-1, 1, 1)
a4 = thetas[:, 3].contiguous().view(-1, 1, 1)
a5 = thetas[:, 4].contiguous().view(-1, 1, 1)
a6 = thetas[:, 5].contiguous().view(-1, 1, 1)
xx, yy = normalize_coords(xx, yy, width=width, height=height)
xq = a1 * xx + a2 * yy + a3
yq = a4 * xx + a5 * yy + a6
xq, yq = denormalize_coords(xq, yq, width=width, height=height)
return xq, yq
def transform_coords(self, width, height, thetas):
xx1, yy1 = self._meshgrid(width=width, height=height, device=thetas.device, dtype=thetas.dtype)
xx, yy = normalize_coords(xx1, yy1, width=width, height=height)
def _unsqueeze12(u):
return torch.unsqueeze(torch.unsqueeze(u, dim=1), dim=1)
a1 = _unsqueeze12(thetas[:, 0])
a2 = _unsqueeze12(thetas[:, 1])
a3 = _unsqueeze12(thetas[:, 2])
a4 = _unsqueeze12(thetas[:, 3])
a5 = _unsqueeze12(thetas[:, 4])
a6 = _unsqueeze12(thetas[:, 5])
#
z = a1 * a5 - a2 * a4
b1 = a5 / z
b2 = - a2 / z
b4 = - a4 / z
b5 = a1 / z
#
xhat = xx - a3
yhat = yy - a6
xq = b1 * xhat + b2 * yhat
yq = b4 * xhat + b5 * yhat
xq, yq = denormalize_coords(xq, yq, width=width, height=height)
return xq, yq
def find_invalid(self, width, height, thetas):
x = self._xbounds
y = self._ybounds
#
a1 = torch.unsqueeze(thetas[:, 0], dim=1)
a2 = torch.unsqueeze(thetas[:, 1], dim=1)
a3 = torch.unsqueeze(thetas[:, 2], dim=1)
a4 = torch.unsqueeze(thetas[:, 3], dim=1)
a5 = torch.unsqueeze(thetas[:, 4], dim=1)
a6 = torch.unsqueeze(thetas[:, 5], dim=1)
#
z = a1 * a5 - a2 * a4
b1 = a5 / z
b2 = - a2 / z
b4 = - a4 / z
b5 = a1 / z
#
xhat = x - a3
yhat = y - a6
xq = b1 * xhat + b2 * yhat
yq = b4 * xhat + b5 * yhat
xq, yq = denormalize_coords(xq, yq, width=width, height=height)
#
invalid = (
(xq < 0) | (yq < 0) | (xq >= width) | (yq >= height)
).sum(dim=1, keepdim=True) > 0
return invalid
def apply_random_transforms_to_params(self,
theta0,
max_translate,
min_zoom, max_zoom,
min_squeeze, max_squeeze,
min_rotate, max_rotate,
validate_size=None):
max_translate *= 0.5
batch_size = theta0.size(0)
height, width = validate_size
# collect valid params here
thetas = torch.zeros_like(theta0)
zoom = theta0.new(batch_size, 1).zero_()
squeeze = torch.zeros_like(zoom)
tx = torch.zeros_like(zoom)
ty = torch.zeros_like(zoom)
phi = torch.zeros_like(zoom)
invalid = torch.ones_like(zoom).byte()
while invalid.sum() > 0:
# random sampling
zoom.uniform_(min_zoom, max_zoom)
squeeze.uniform_(min_squeeze, max_squeeze)
tx.uniform_(-max_translate, max_translate)
ty.uniform_(-max_translate, max_translate)
phi.uniform_(min_rotate, max_rotate)
# construct affine parameters
sx = zoom * squeeze
sy = zoom / squeeze
sin_phi = torch.sin(phi)
cos_phi = torch.cos(phi)
b1 = cos_phi * sx
b2 = sin_phi * sy
b3 = tx
b4 = - sin_phi * sx
b5 = cos_phi * sy
b6 = ty
theta_transform = torch.cat([b1, b2, b3, b4, b5, b6], dim=1)
theta_try = apply_transform_to_params(theta0, theta_transform)
thetas = invalid.float() * theta_try + (1 - invalid.float()) * thetas
# compute new invalid ones
invalid = self.find_invalid(width=width, height=height, thetas=thetas)
# here we should have good thetas within borders
return thetas
def transform_image(self, images, thetas):
batch_size, channels, height, width = images.size()
xq, yq = self.transform_coords(width=width, height=height, thetas=thetas)
transformed = self._interp2(images, xq, yq)
return transformed
def transform_flow(self, flow, theta1, theta2):
batch_size, channels, height, width = flow.size()
u = flow[:, 0, :, :]
v = flow[:, 1, :, :]
# inverse transform coords
x0, y0 = self.inverse_transform_coords(
width=width, height=height, thetas=theta1)
x1, y1 = self.inverse_transform_coords(
width=width, height=height, thetas=theta2, offset_x=u, offset_y=v)
# subtract and create new flow
u = x1 - x0
v = y1 - y0
new_flow = torch.stack([u, v], dim=1)
# transform coords
xq, yq = self.transform_coords(width=width, height=height, thetas=theta1)
# interp2
transformed = self._flow_interp2(new_flow, xq, yq)
return transformed
def check_out_of_bound(self, flow, occ, batch_size):
_, _, height, width = flow.size()
u = flow[:, 0, :, :]
v = flow[:, 1, :, :]
xx, yy = self._meshgrid(width=width, height=height, device=flow.device, dtype=flow.dtype)
xx = torch.unsqueeze(xx, dim=0)
yy = torch.unsqueeze(yy, dim=0)
xx = xx.expand(batch_size, -1, -1) + u
yy = yy.expand(batch_size, -1, -1) + v
out_of_bound = ((xx < 0) | (yy < 0) | (xx >= width) | (yy >= height)).float().unsqueeze(1)
occ = torch.clamp(out_of_bound + occ, 0, 1)
return occ
def random_crop(self, im1, im2, flo_f, occ1):
_, _, height, width = im1.size()
crop_height, crop_width = self._crop
# get starting positions
self._x.random_(0, width - crop_width + 1)
self._y.random_(0, height - crop_height + 1)
str_x = int(self._x)
str_y = int(self._y)
end_x = int(self._x + crop_width)
end_y = int(self._y + crop_height)
im1 = im1[:, :, str_y:end_y, str_x:end_x]
im2 = im2[:, :, str_y:end_y, str_x:end_x]
flo_f = flo_f[:, :, str_y:end_y, str_x:end_x]
occ1 = occ1[:, :, str_y:end_y, str_x:end_x]
return im1, im2, flo_f, occ1
def forward(self, example_dict):
im1 = example_dict["input1"]
im2 = example_dict["input2"]
flo_f = example_dict["target1"]
occ1 = example_dict["target_occ1"]
batch_size = im1.size(0)
height = im1.size(2)
width = im1.size(3)
# identity = no transform
theta0 = self._identity(batch_size)
# global transform
theta1 = self.apply_random_transforms_to_params(
theta0,
max_translate=0.2,
min_zoom=1.0, max_zoom=1.5,
min_squeeze=0.86, max_squeeze=1.16,
min_rotate=-0.2, max_rotate=0.2,
validate_size=[height, width])
# relative transform
theta2 = self.apply_random_transforms_to_params(
theta1,
max_translate=0.015,
min_zoom=0.985, max_zoom=1.015,
min_squeeze=1.0, max_squeeze=1.0,
min_rotate=-0.015, max_rotate=0.015,
validate_size=[height, width])
# random flip images
theta1, theta2 = self._random_mirror(theta1, theta2)
im1 = self.transform_image(im1, theta1)
im2 = self.transform_image(im2, theta2)
flo_f = self.transform_flow(flo_f, theta1, theta2)
occ1 = self.transform_image(occ1, theta1)
if self._addnoise:
stddev = np.random.uniform(0.0, 0.04)
self._noise1.resize_as_(im1)
self._noise2.resize_as_(im2)
self._noise1.normal_(std=stddev)
self._noise2.normal_(std=stddev)
im1 += self._noise1
im2 += self._noise2
im1.clamp_(0.0, 1.0)
im2.clamp_(0.0, 1.0)
if self._crop is not None:
im1, im2, flo_f, occ1 = self.random_crop(im1, im2, flo_f, occ1)
occ1 = self.check_out_of_bound(flo_f, occ1, batch_size)
example_dict["input1"] = im1
example_dict["input2"] = im2
example_dict["target1"] = flo_f
example_dict["target_occ1"] = occ1
return example_dict
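`transform_coords` and `find_invalid` both invert the affine map in closed form: for theta mapping (x, y) to (a1*x + a2*y + a3, a4*x + a5*y + a6), the 2x2 linear part is inverted via its adjugate divided by the determinant z = a1*a5 - a2*a4, after undoing the translation. A small NumPy round-trip check of that algebra (the helper names and the example theta are illustrative, not repo code):

```python
import numpy as np

def forward_affine(theta, x, y):
    a1, a2, a3, a4, a5, a6 = theta
    return a1 * x + a2 * y + a3, a4 * x + a5 * y + a6

def inverse_affine(theta, xq, yq):
    a1, a2, a3, a4, a5, a6 = theta
    z = a1 * a5 - a2 * a4          # determinant of the linear part
    b1, b2 = a5 / z, -a2 / z       # adjugate / determinant, as in the code
    b4, b5 = -a4 / z, a1 / z
    xhat, yhat = xq - a3, yq - a6  # undo the translation first
    return b1 * xhat + b2 * yhat, b4 * xhat + b5 * yhat

theta = [1.2, 0.1, 0.3, -0.1, 0.9, -0.2]   # arbitrary non-singular example
x, y = np.array([0.5, -0.4]), np.array([0.25, 0.8])
xq, yq = forward_affine(theta, x, y)
xr, yr = inverse_affine(theta, xq, yq)
```

Applying the inverse after the forward map recovers the original coordinates, which is exactly the invariant `transform_flow` relies on when it composes `inverse_transform_coords` for theta1 and theta2.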
class RandomAffineFlowOccKITTI(nn.Module):
def __init__(self, args, addnoise=True, crop=None):
super(RandomAffineFlowOccKITTI, self).__init__()
self._args = args
self._interp2 = Interp2(clamp=False)
self._flow_interp2 = Interp2MaskBinary(clamp=False)
self._meshgrid = Meshgrid()
self._identity = _IdentityParams()
self._random_mirror = RandomMirror(vertical=False)
self._addnoise = addnoise
self._crop = crop
self.register_buffer("_noise1", torch.FloatTensor())
self.register_buffer("_noise2", torch.FloatTensor())
self.register_buffer("_xbounds", torch.FloatTensor([-1, -1, 1, 1]))
self.register_buffer("_ybounds", torch.FloatTensor([-1, 1, -1, 1]))
self.register_buffer("_x", torch.IntTensor(1))
self.register_buffer("_y", torch.IntTensor(1))
def inverse_transform_coords(self, width, height, thetas, offset_x=None, offset_y=None):
xx, yy = self._meshgrid(width=width, height=height, device=thetas.device, dtype=thetas.dtype)
xx = torch.unsqueeze(xx, dim=0).float()
yy = torch.unsqueeze(yy, dim=0).float()
if offset_x is not None:
xx = xx + offset_x
if offset_y is not None:
yy = yy + offset_y
a1 = thetas[:, 0].contiguous().view(-1, 1, 1)
a2 = thetas[:, 1].contiguous().view(-1, 1, 1)
a3 = thetas[:, 2].contiguous().view(-1, 1, 1)
a4 = thetas[:, 3].contiguous().view(-1, 1, 1)
a5 = thetas[:, 4].contiguous().view(-1, 1, 1)
a6 = thetas[:, 5].contiguous().view(-1, 1, 1)
xx, yy = normalize_coords(xx, yy, width=width, height=height)
xq = a1 * xx + a2 * yy + a3
yq = a4 * xx + a5 * yy + a6
xq, yq = denormalize_coords(xq, yq, width=width, height=height)
return xq, yq
def transform_coords(self, width, height, thetas):
xx1, yy1 = self._meshgrid(width=width, height=height, device=thetas.device, dtype=thetas.dtype)
xx, yy = normalize_coords(xx1, yy1, width=width, height=height)
def _unsqueeze12(u):
return torch.unsqueeze(torch.unsqueeze(u, dim=1), dim=1)
a1 = _unsqueeze12(thetas[:, 0])
a2 = _unsqueeze12(thetas[:, 1])
a3 = _unsqueeze12(thetas[:, 2])
a4 = _unsqueeze12(thetas[:, 3])
a5 = _unsqueeze12(thetas[:, 4])
a6 = _unsqueeze12(thetas[:, 5])
#
z = a1 * a5 - a2 * a4
b1 = a5 / z
b2 = - a2 / z
b4 = - a4 / z
b5 = a1 / z
#
xhat = xx - a3
yhat = yy - a6
xq = b1 * xhat + b2 * yhat
yq = b4 * xhat + b5 * yhat
xq, yq = denormalize_coords(xq, yq, width=width, height=height)
return xq, yq
def find_invalid(self, width, height, thetas):
x = self._xbounds
y = self._ybounds
#
a1 = torch.unsqueeze(thetas[:, 0], dim=1)
a2 = torch.unsqueeze(thetas[:, 1], dim=1)
a3 = torch.unsqueeze(thetas[:, 2], dim=1)
a4 = torch.unsqueeze(thetas[:, 3], dim=1)
a5 = torch.unsqueeze(thetas[:, 4], dim=1)
a6 = torch.unsqueeze(thetas[:, 5], dim=1)
#
z = a1 * a5 - a2 * a4
b1 = a5 / z
b2 = - a2 / z
b4 = - a4 / z
b5 = a1 / z
#
xhat = x - a3
yhat = y - a6
xq = b1 * xhat + b2 * yhat
yq = b4 * xhat + b5 * yhat
xq, yq = denormalize_coords(xq, yq, width=width, height=height)
#
invalid = (
(xq < 0) | (yq < 0) | (xq >= width) | (yq >= height)
).sum(dim=1, keepdim=True) > 0
return invalid
def apply_random_transforms_to_params(self,
theta0,
max_translate,
min_zoom, max_zoom,
min_squeeze, max_squeeze,
min_rotate, max_rotate,
validate_size=None):
max_translate *= 0.5
batch_size = theta0.size(0)
height, width = validate_size
# collect valid params here
thetas = torch.zeros_like(theta0)
zoom = theta0.new(batch_size, 1).zero_()
squeeze = torch.zeros_like(zoom)
tx = torch.zeros_like(zoom)
ty = torch.zeros_like(zoom)
phi = torch.zeros_like(zoom)
invalid = torch.ones_like(zoom).byte()
while invalid.sum() > 0:
# random sampling
zoom.uniform_(min_zoom, max_zoom)
squeeze.uniform_(min_squeeze, max_squeeze)
tx.uniform_(-max_translate, max_translate)
ty.uniform_(-max_translate, max_translate)
phi.uniform_(min_rotate, max_rotate)
# construct affine parameters
sx = zoom * squeeze
sy = zoom / squeeze
sin_phi = torch.sin(phi)
cos_phi = torch.cos(phi)
b1 = cos_phi * sx
b2 = sin_phi * sy
b3 = tx
b4 = - sin_phi * sx
b5 = cos_phi * sy
b6 = ty
theta_transform = torch.cat([b1, b2, b3, b4, b5, b6], dim=1)
theta_try = apply_transform_to_params(theta0, theta_transform)
thetas = invalid.float() * theta_try + (1 - invalid.float()) * thetas
# compute new invalid ones
invalid = self.find_invalid(width=width, height=height, thetas=thetas)
# here we should have good thetas within borders
return thetas
def transform_image(self, images, thetas):
batch_size, channels, height, width = images.size()
xq, yq = self.transform_coords(width=width, height=height, thetas=thetas)
transformed = self._interp2(images, xq, yq)
return transformed
def transform_flow(self, flow, theta1, theta2, valid_mask):
batch_size, channels, height, width = flow.size()
u = flow[:, 0, :, :]
v = flow[:, 1, :, :]
# inverse transform coords
x0, y0 = self.inverse_transform_coords(
width=width, height=height, thetas=theta1)
x1, y1 = self.inverse_transform_coords(
width=width, height=height, thetas=theta2, offset_x=u, offset_y=v)
# subtract and create new flow
u = x1 - x0
v = y1 - y0
new_flow = torch.stack([u, v], dim=1)
# transform coords
xq, yq = self.transform_coords(width=width, height=height, thetas=theta1)
# interp2
# transformed = self._interp2(new_flow, xq, yq)
transformed, valid_mask = self._flow_interp2(new_flow, xq, yq, valid_mask)
return transformed, valid_mask
def check_out_of_bound(self, flow, occ, batch_size):
_, _, height, width = flow.size()
u = flow[:, 0, :, :]
v = flow[:, 1, :, :]
xx, yy = self._meshgrid(width=width, height=height, device=flow.device, dtype=flow.dtype)
xx = torch.unsqueeze(xx, dim=0).float()
yy = torch.unsqueeze(yy, dim=0).float()
xx = xx.expand(batch_size, -1, -1) + u
yy = yy.expand(batch_size, -1, -1) + v
out_of_bound = ((xx < 0) | (yy < 0) | (xx >= width) | (yy >= height)).float().unsqueeze(1)
occ = torch.clamp(out_of_bound + occ, 0, 1)
return occ
def random_crop(self, im1, im2, flo_f, valid_mask):
_, _, height, width = im1.size()
crop_height, crop_width = self._crop
# get starting positions
self._x.random_(0, width - crop_width + 1)
self._y.random_(0, height - crop_height + 1)
str_x = int(self._x)
str_y = int(self._y)
end_x = int(self._x + crop_width)
end_y = int(self._y + crop_height)
im1 = im1[:, :, str_y:end_y, str_x:end_x]
im2 = im2[:, :, str_y:end_y, str_x:end_x]
flo_f = flo_f[:, :, str_y:end_y, str_x:end_x]
valid_mask = valid_mask[:, :, str_y:end_y, str_x:end_x]
return im1, im2, flo_f, valid_mask
def forward(self, example_dict):
im1 = example_dict["input1"]
im2 = example_dict["input2"]
flo_f = example_dict["target1"]
valid_mask = example_dict["input_valid"]
batch_size = im1.size(0)
height = im1.size(2)
width = im1.size(3)
# identity = no transform
theta0 = self._identity(batch_size)
# global transform
theta1 = self.apply_random_transforms_to_params(
theta0,
max_translate=0.04,
min_zoom=0.98, max_zoom=1.02,
min_squeeze=1.0, max_squeeze=1.0,
min_rotate=-0.01, max_rotate=0.01,
validate_size=[height, width])
# relative transform
theta2 = self.apply_random_transforms_to_params(
theta1,
max_translate=0.005,
min_zoom=0.99, max_zoom=1.01,
min_squeeze=1.0, max_squeeze=1.0,
min_rotate=-0.01, max_rotate=0.01,
validate_size=[height, width])
# random flip images
theta1, theta2 = self._random_mirror(theta1, theta2)
im1 = self.transform_image(im1, theta1)
im2 = self.transform_image(im2, theta2)
flo_f, valid_mask = self.transform_flow(flo_f, theta1, theta2, valid_mask)
if self._addnoise:
stddev = np.random.uniform(0.0, 0.04)
self._noise1.resize_as_(im1)
self._noise2.resize_as_(im2)
self._noise1.normal_(std=stddev)
self._noise2.normal_(std=stddev)
im1 += self._noise1
im2 += self._noise2
im1.clamp_(0.0, 1.0)
im2.clamp_(0.0, 1.0)
if self._crop is not None:
im1, im2, flo_f, valid_mask = self.random_crop(im1, im2, flo_f, valid_mask)
example_dict["input1"] = im1
example_dict["input2"] = im2
example_dict["target1"] = flo_f
example_dict["input_valid"] = valid_mask
return example_dict
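The `check_out_of_bound` method used by the classes above marks as occluded any pixel whose flow target (x + u, y + v) leaves the image, on top of the existing occlusion mask. A NumPy sketch of the same logic (illustrative, not the repo's torch implementation; `out_of_bound_occ` is a hypothetical helper name):

```python
import numpy as np

def out_of_bound_occ(flow, occ):
    # flow: (2, H, W) with channels (u, v); occ: (1, H, W) in {0, 1}
    _, h, w = flow.shape
    yy, xx = np.mgrid[0:h, 0:w]          # pixel coordinate grids
    xt = xx + flow[0]                     # target x of each pixel
    yt = yy + flow[1]                     # target y of each pixel
    oob = ((xt < 0) | (yt < 0) | (xt >= w) | (yt >= h)).astype(flow.dtype)
    # union of existing occlusions and out-of-bound pixels, clamped to [0, 1]
    return np.clip(occ + oob[None], 0.0, 1.0)

flow = np.zeros((2, 4, 4))
flow[0, :, 3] = 1.0                       # rightmost column flows out of frame
occ = out_of_bound_occ(flow, np.zeros((1, 4, 4)))
```

Only the four pixels in the last column end up occluded, since their target x coordinate (3 + 1 = 4) falls outside the 4-pixel-wide image.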
================================================
FILE: commandline.py
================================================
## Portions of code copyright 2018 Jochen Gast
from __future__ import absolute_import, division, print_function
import argparse
import colorama
import inspect
import os
import sys
import torch
import datasets
import losses
import models
import augmentations
import tools
import logger
import logging
import optim
def _get_type_from_arg(arg):
if isinstance(arg, bool):
return tools.str2bool
else:
return type(arg)
def _add_arguments_for_module(parser,
module,
name,
default_class,
add_class_argument=True, # whether to add class choice as argument
include_classes="*",
exclude_classes=[],
exclude_params=["self","args"],
param_defaults={}, # allows to overwrite any default param
forced_default_types={}, # allows to set types for known arguments
unknown_default_types={}): # allows to set types for unknown arguments
# -------------------------------------------------------------------------
# Determine possible choices from class names in module, possibly apply include/exclude filters
# -------------------------------------------------------------------------
module_dict = tools.module_classes_to_dict(
module, include_classes=include_classes, exclude_classes=exclude_classes)
# -------------------------------------------------------------------------
# Parse known arguments to determine choice for argument name
# -------------------------------------------------------------------------
if add_class_argument:
parser.add_argument(
"--%s" % name, type=str, default=default_class, choices=module_dict.keys())
known_args = parser.parse_known_args(sys.argv[1:])[0]
else:
# build a temporary parser, and do not add the class as argument
tmp_parser = argparse.ArgumentParser()
tmp_parser.add_argument(
"--%s" % name, type=str, default=default_class, choices=module_dict.keys())
known_args = tmp_parser.parse_known_args(sys.argv[1:])[0]
class_name = vars(known_args)[name]
# -------------------------------------------------------------------------
# If class is None, there is no point in trying to parse further arguments
# -------------------------------------------------------------------------
if class_name is None:
return
# -------------------------------------------------------------------------
# Get constructor of that argument choice
# -------------------------------------------------------------------------
class_constructor = module_dict[class_name]
# -------------------------------------------------------------------------
# Determine constructor argument names and defaults
# -------------------------------------------------------------------------
try:
# getfullargspec replaces inspect.getargspec, which was deprecated and
# removed in Python 3.11
argspec = inspect.getfullargspec(class_constructor.__init__)
argspec_defaults = argspec.defaults if argspec.defaults is not None else []
full_args = argspec.args
default_args_dict = dict(zip(argspec.args[-len(argspec_defaults):], argspec_defaults))
except TypeError:
# if getfullargspec itself raised, 'argspec' is unbound, so printing it
# here would only mask the error with a NameError
raise ValueError("unknown_default_types should be adjusted for module: '%s.py'" % name)
# -------------------------------------------------------------------------
# Add sub_arguments
# -------------------------------------------------------------------------
for argname in full_args:
# ---------------------------------------------------------------------
# Skip
# ---------------------------------------------------------------------
if argname in exclude_params:
continue
# ---------------------------------------------------------------------
# Sub argument name
# ---------------------------------------------------------------------
sub_arg_name = "%s_%s" % (name, argname)
# ---------------------------------------------------------------------
# If a default argument is given, take that one
# ---------------------------------------------------------------------
if argname in param_defaults.keys():
parser.add_argument(
"--%s" % sub_arg_name,
type=_get_type_from_arg(param_defaults[argname]),
default=param_defaults[argname])
# ---------------------------------------------------------------------
# If a default parameter can be inferred from the module, pick that one
# ---------------------------------------------------------------------
elif argname in default_args_dict.keys():
# -----------------------------------------------------------------
# Check for forced default types
# -----------------------------------------------------------------
if argname in forced_default_types.keys():
argtype = forced_default_types[argname]
else:
argtype = _get_type_from_arg(default_args_dict[argname])
parser.add_argument(
"--%s" % sub_arg_name, type=argtype, default=default_args_dict[argname])
# ---------------------------------------------------------------------
# Take from the unknowns list
# ---------------------------------------------------------------------
elif argname in unknown_default_types.keys():
parser.add_argument("--%s" % sub_arg_name, type=unknown_default_types[argname])
else:
raise ValueError(
"Do not know how to handle argument '%s' for class '%s'" % (argname, name))
def _add_special_arguments(parser):
# -------------------------------------------------------------------------
# Known arguments so far
# -------------------------------------------------------------------------
known_args = vars(parser.parse_known_args(sys.argv[1:])[0])
# -------------------------------------------------------------------------
# Add special arguments for training
# -------------------------------------------------------------------------
training_loss = known_args["training_loss"]
if training_loss is not None:
parser.add_argument("--training_key", type=str, default="total_loss")
# -------------------------------------------------------------------------
# Add special arguments for validation
# -------------------------------------------------------------------------
validation_loss = known_args["validation_loss"]
if validation_loss is not None:
parser.add_argument("--validation_key", type=str, default="total_loss")
parser.add_argument("--validation_key_minimize", type=tools.str2bool, default=True)
# -------------------------------------------------------------------------
# Add special arguments for checkpoints
# -------------------------------------------------------------------------
checkpoint = known_args["checkpoint"]
if checkpoint is not None:
parser.add_argument(
"--checkpoint_mode", type=str, default="resume_from_latest",
choices=["resume_from_latest", "resume_from_best"])
parser.add_argument(
"--checkpoint_include_params", type=tools.str2list, default="[*]")
parser.add_argument(
"--checkpoint_exclude_params", type=tools.str2list, default="[]")
# -------------------------------------------------------------------------
# Add special arguments for optimizer groups
# -------------------------------------------------------------------------
parser.add_argument("--optimizer_group", action="append", type=tools.str2dict, default=None)
def _parse_arguments():
# -------------------------------------------------------------------------
# Argument parser and shortcut function to add arguments
# -------------------------------------------------------------------------
parser = argparse.ArgumentParser()
add = parser.add_argument
# -------------------------------------------------------------------------
# Standard arguments
# -------------------------------------------------------------------------
add("--batch_size", type=int, default=1)
add("--batch_size_val", type=int, default=1)
add("--checkpoint", type=tools.str2str_or_none, default=None)
add("--cuda", type=tools.str2bool, default=True)
add("--evaluation", type=tools.str2bool, default=False)
add("--name", default="run", type=str)
add("--num_workers", type=int, default=4)
add("--save", "-s", default="/tmp/work", type=str)
add("--seed", type=int, default=1)
add("--start_epoch", type=int, default=1)
add("--total_epochs", type=int, default=10)
add("--save_result_path_name", default="", type=str)
add("--save_result_img", type=tools.str2bool, default=False)
add("--save_result_occ", type=tools.str2bool, default=False)
add("--save_result_flo", type=tools.str2bool, default=False)
add("--save_result_png", type=tools.str2bool, default=False)
add("--save_result_bidirection", type=tools.str2bool, default=False)
add("--num_iters", type=int, default=1)
# -------------------------------------------------------------------------
# Arguments inferred from losses
# -------------------------------------------------------------------------
_add_arguments_for_module(
parser,
losses,
name="training_loss",
default_class=None,
exclude_classes=["_*", "Variable"],
exclude_params=["self","args"])
_add_arguments_for_module(
parser,
losses,
name="validation_loss",
default_class=None,
exclude_classes=["_*", "Variable"],
exclude_params=["self","args"])
# -------------------------------------------------------------------------
# Arguments inferred from models
# -------------------------------------------------------------------------
_add_arguments_for_module(
parser,
models,
name="model",
default_class="FlowNet1S",
exclude_classes=["_*", "Variable"],
exclude_params=["self","args"])
# -------------------------------------------------------------------------
# Arguments inferred from augmentations for training
# -------------------------------------------------------------------------
_add_arguments_for_module(
parser,
augmentations,
name="training_augmentation",
default_class=None,
exclude_classes=["_*"],
exclude_params=["self","args"],
forced_default_types={"crop": tools.str2intlist})
# -------------------------------------------------------------------------
# Arguments inferred from augmentations for validation
# -------------------------------------------------------------------------
_add_arguments_for_module(
parser,
augmentations,
name="validation_augmentation",
default_class=None,
exclude_classes=["_*"],
exclude_params=["self","args"])
# -------------------------------------------------------------------------
# Arguments inferred from datasets for training
# -------------------------------------------------------------------------
_add_arguments_for_module(
parser,
datasets,
name="training_dataset",
default_class=None,
exclude_params=["self", "args", "is_cropped"],
exclude_classes=["_*"],
unknown_default_types={"root": str})
# -------------------------------------------------------------------------
# Arguments inferred from datasets for validation
# -------------------------------------------------------------------------
_add_arguments_for_module(
parser,
datasets,
name="validation_dataset",
default_class=None,
exclude_params=["self", "args", "is_cropped"],
exclude_classes=["_*"],
unknown_default_types={"root": str})
# -------------------------------------------------------------------------
# Arguments inferred from PyTorch optimizers
# -------------------------------------------------------------------------
_add_arguments_for_module(
parser,
optim,
name="optimizer",
default_class="Adam",
exclude_classes=["_*","Optimizer", "constructor"],
exclude_params=["self", "args", "params"],
forced_default_types={"lr": float,
"momentum": float,
"dampening": float,
"weight_decay": float,
"nesterov": tools.str2bool})
# -------------------------------------------------------------------------
# Arguments inferred from PyTorch lr schedulers
# -------------------------------------------------------------------------
_add_arguments_for_module(
parser,
torch.optim.lr_scheduler,
name="lr_scheduler",
default_class=None,
exclude_classes=["_*","Optimizer"],
exclude_params=["self", "args", "optimizer"],
unknown_default_types={"T_max": int,
"lr_lambda": str,
"step_size": int,
"milestones": tools.str2intlist,
"gamma": float})
# -------------------------------------------------------------------------
# Special arguments
# -------------------------------------------------------------------------
_add_special_arguments(parser)
# -------------------------------------------------------------------------
# Parse arguments
# -------------------------------------------------------------------------
args = parser.parse_args()
# -------------------------------------------------------------------------
# Parse default arguments from a dummy commandline not specifying any args
# -------------------------------------------------------------------------
defaults = vars(parser.parse_known_args(['--dummy'])[0])
# -------------------------------------------------------------------------
# Consistency checks
# -------------------------------------------------------------------------
args.cuda = args.cuda and torch.cuda.is_available()
return args, defaults
def postprocess_args(args):
# ----------------------------------------------------------------------------
# Get appropriate class constructors from modules
# ----------------------------------------------------------------------------
args.model_class = tools.module_classes_to_dict(models)[args.model]
if args.optimizer is not None:
optimizer_classes = tools.module_classes_to_dict(optim)
args.optimizer_class = optimizer_classes[args.optimizer]
if args.training_loss is not None:
loss_classes = tools.module_classes_to_dict(losses)
args.training_loss_class = loss_classes[args.training_loss]
if args.validation_loss is not None:
loss_classes = tools.module_classes_to_dict(losses)
args.validation_loss_class = loss_classes[args.validation_loss]
if args.lr_scheduler is not None:
scheduler_classes = tools.module_classes_to_dict(torch.optim.lr_scheduler)
args.lr_scheduler_class = scheduler_classes[args.lr_scheduler]
if args.training_dataset is not None:
dataset_classes = tools.module_classes_to_dict(datasets)
args.training_dataset_class = dataset_classes[args.training_dataset]
if args.validation_dataset is not None:
dataset_classes = tools.module_classes_to_dict(datasets)
args.validation_dataset_class = dataset_classes[args.validation_dataset]
if args.training_augmentation is not None:
augmentation_classes = tools.module_classes_to_dict(augmentations)
args.training_augmentation_class = augmentation_classes[args.training_augmentation]
if args.validation_augmentation is not None:
augmentation_classes = tools.module_classes_to_dict(augmentations)
args.validation_augmentation_class = augmentation_classes[args.validation_augmentation]
return args
def setup_logging_and_parse_arguments(blocktitle):
# ----------------------------------------------------------------------------
# Get parse commandline and default arguments
# ----------------------------------------------------------------------------
args, defaults = _parse_arguments()
# ----------------------------------------------------------------------------
# Setup logbook before everything else
# ----------------------------------------------------------------------------
logger.configure_logging(os.path.join(args.save, 'logbook.txt'))
# ----------------------------------------------------------------------------
# Write arguments to file, as txt
# ----------------------------------------------------------------------------
tools.write_dictionary_to_file(
sorted(vars(args).items()),
filename=os.path.join(args.save, 'args.txt'))
# ----------------------------------------------------------------------------
# Log arguments
# ----------------------------------------------------------------------------
with logger.LoggingBlock(blocktitle, emph=True):
for argument, value in sorted(vars(args).items()):
reset = colorama.Style.RESET_ALL
color = reset if value == defaults[argument] else colorama.Fore.CYAN
logging.info('{}{}: {}{}'.format(color, argument, value, reset))
# ----------------------------------------------------------------------------
# Postprocess
# ----------------------------------------------------------------------------
args = postprocess_args(args)
return args
================================================
FILE: configuration.py
================================================
## Portions of code copyright 2018 Jochen Gast
from __future__ import absolute_import, division, print_function
import os
import torch
from torch import nn
import numpy as np
from torch.utils.data import DataLoader
import logger
import tools
import logging
import shutil
import random
import fnmatch
# ---------------------------------------------------
# Class that contains both the network model and loss
# ---------------------------------------------------
class ModelAndLoss(nn.Module):
def __init__(self, args, model, training_loss, evaluation_loss=None):
super(ModelAndLoss, self).__init__()
self._model = model
self._training_loss = training_loss
self._evaluation_loss = evaluation_loss
@property
def training_loss(self):
return self._training_loss
@property
def evaluation_loss(self):
return self._evaluation_loss
@property
def model(self):
return self._model
def num_parameters(self):
return sum([p.data.nelement() if p.requires_grad else 0 for p in self.parameters()])
# -------------------------------------------------------------
# Note: We merge inputs and targets into a single dictionary !
# -------------------------------------------------------------
def forward(self, example_dict):
# -------------------------------------
# Run forward pass
# -------------------------------------
output_dict = self._model(example_dict)
# -------------------------------------
# Compute losses
# -------------------------------------
if self.training:
loss_dict = self._training_loss(output_dict, example_dict)
else:
loss_dict = self._evaluation_loss(output_dict, example_dict)
# -------------------------------------
# Return losses and outputs
# -------------------------------------
return loss_dict, output_dict
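The wrapper above runs the model once and then dispatches to the training or evaluation loss based on `nn.Module`'s `training` flag, with inputs and targets traveling together in one dictionary. A minimal, torch-free sketch of that dispatch pattern (all names here are hypothetical stand-ins, not the repo's classes):

```python
# Minimal sketch of the ModelAndLoss dispatch pattern: run the model once,
# then pick the training or evaluation loss depending on the current mode.
class ModelAndLossSketch:
    def __init__(self, model, training_loss, evaluation_loss):
        self._model = model
        self._training_loss = training_loss
        self._evaluation_loss = evaluation_loss
        self.training = True  # stands in for nn.Module.train()/eval()

    def forward(self, example_dict):
        # inputs and targets are merged into a single dictionary
        output_dict = self._model(example_dict)
        loss_fn = self._training_loss if self.training else self._evaluation_loss
        return loss_fn(output_dict, example_dict), output_dict

# Toy model and losses, purely for illustration.
model = lambda ex: {"flow": ex["input1"] * 2}
train_loss = lambda out, ex: {"total_loss": abs(out["flow"] - ex["target1"])}
eval_loss = lambda out, ex: {"epe": abs(out["flow"] - ex["target1"])}

wrapper = ModelAndLossSketch(model, train_loss, eval_loss)
loss_dict, output_dict = wrapper.forward({"input1": 3, "target1": 5})
print(loss_dict)  # {'total_loss': 1}
```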
def configure_runtime_augmentations(args):
with logger.LoggingBlock("Runtime Augmentations", emph=True):
training_augmentation = None
validation_augmentation = None
# ----------------------------------------------------
# Training Augmentation
# ----------------------------------------------------
if args.training_augmentation is not None:
kwargs = tools.kwargs_from_args(args, "training_augmentation")
logging.info("training_augmentation: %s" % args.training_augmentation)
for param, default in sorted(kwargs.items()):
logging.info(" %s: %s" % (param, default))
kwargs["args"] = args
training_augmentation = tools.instance_from_kwargs(
args.training_augmentation_class, kwargs)
if args.cuda:
training_augmentation = training_augmentation.cuda()
else:
logging.info("training_augmentation: None")
# ----------------------------------------------------
# Validation Augmentation
# ----------------------------------------------------
if args.validation_augmentation is not None:
kwargs = tools.kwargs_from_args(args, "validation_augmentation")
logging.info("validation_augmentation: %s" % args.validation_augmentation)
for param, default in sorted(kwargs.items()):
logging.info(" %s: %s" % (param, default))
kwargs["args"] = args
validation_augmentation = tools.instance_from_kwargs(
args.validation_augmentation_class, kwargs)
if args.cuda:
validation_augmentation = validation_augmentation.cuda()
else:
logging.info("validation_augmentation: None")
return training_augmentation, validation_augmentation
def configure_model_and_loss(args):
# ----------------------------------------------------
# Dynamically load model and loss class with parameters
# passed in via "--model_[param]=[value]" or "--loss_[param]=[value]" arguments
# ----------------------------------------------------
with logger.LoggingBlock("Model and Loss", emph=True):
# ----------------------------------------------------
# Model
# ----------------------------------------------------
kwargs = tools.kwargs_from_args(args, "model")
kwargs["args"] = args
model = tools.instance_from_kwargs(args.model_class, kwargs)
# ----------------------------------------------------
# Training loss
# ----------------------------------------------------
training_loss = None
if args.training_loss is not None:
kwargs = tools.kwargs_from_args(args, "training_loss")
kwargs["args"] = args
training_loss = tools.instance_from_kwargs(args.training_loss_class, kwargs)
# ----------------------------------------------------
# Validation loss
# ----------------------------------------------------
validation_loss = None
if args.validation_loss is not None:
kwargs = tools.kwargs_from_args(args, "validation_loss")
kwargs["args"] = args
validation_loss = tools.instance_from_kwargs(args.validation_loss_class, kwargs)
# ----------------------------------------------------
# Model and loss
# ----------------------------------------------------
model_and_loss = ModelAndLoss(args, model, training_loss, validation_loss)
# -----------------------------------------------------------
# If Cuda, transfer model to Cuda and wrap with DataParallel.
# -----------------------------------------------------------
if args.cuda:
model_and_loss = model_and_loss.cuda()
# ---------------------------------------------------------------
# Report some network statistics
# ---------------------------------------------------------------
logging.info("Batch Size: %i" % args.batch_size)
logging.info("GPGPU: Cuda" if args.cuda else "GPGPU: off")
logging.info("Network: %s" % args.model)
logging.info("Number of parameters: %i" % tools.x2module(model_and_loss).num_parameters())
if training_loss is not None:
logging.info("Training Key: %s" % args.training_key)
logging.info("Training Loss: %s" % args.training_loss)
if validation_loss is not None:
logging.info("Validation Key: %s" % args.validation_key)
logging.info("Validation Loss: %s" % args.validation_loss)
return model_and_loss
def configure_random_seed(args):
with logger.LoggingBlock("Random Seeds", emph=True):
# python
seed = args.seed
random.seed(seed)
logging.info("Python seed: %i" % seed)
# numpy
seed += 1
np.random.seed(seed)
logging.info("Numpy seed: %i" % seed)
# torch
seed += 1
torch.manual_seed(seed)
logging.info("Torch CPU seed: %i" % seed)
# torch cuda
seed += 1
torch.cuda.manual_seed(seed)
logging.info("Torch CUDA seed: %i" % seed)
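The seeding scheme above derives one seed per random number generator by incrementing a base seed, so Python, NumPy, and Torch never receive an identical seed. A small sketch of the derivation:

```python
# Derive distinct per-library seeds from a single base seed,
# mirroring the increment-by-one scheme in configure_random_seed.
base_seed = 42
rng_names = ["python", "numpy", "torch_cpu", "torch_cuda"]
seeds = {name: base_seed + offset for offset, name in enumerate(rng_names)}
print(seeds)  # {'python': 42, 'numpy': 43, 'torch_cpu': 44, 'torch_cuda': 45}
```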
# --------------------------------------------------------------------------
# Checkpoint loader/saver.
# --------------------------------------------------------------------------
class CheckpointSaver:
def __init__(self,
prefix="checkpoint",
latest_postfix="_latest",
best_postfix="_best",
model_key="state_dict",
extension=".ckpt"):
self._prefix = prefix
self._model_key = model_key
self._latest_postfix = latest_postfix
self._best_postfix = best_postfix
self._extension = extension
# We override the default state_dict loading function because we sometimes
# want to initialize module parameters without knowing their dimensions at
# construction time.
#
# This version resizes such parameters to whatever size the checkpoint requires.
def _load_state_dict_into_module(self, state_dict, module, strict=True):
own_state = module.state_dict()
for name, param in state_dict.items():
if name in own_state:
if isinstance(param, nn.Parameter):
# backwards compatibility for serialized parameters
param = param.data
try:
own_state[name].resize_as_(param)
own_state[name].copy_(param)
except Exception:
raise RuntimeError('Error while copying the parameter named {}: '
'its dimensions in the model are {}, but its '
'dimensions in the checkpoint are {}.'
.format(name, own_state[name].size(), param.size()))
elif strict:
raise KeyError('unexpected key "{}" in state_dict'
.format(name))
if strict:
missing = set(own_state.keys()) - set(state_dict.keys())
if len(missing) > 0:
raise KeyError('missing keys in state_dict: "{}"'.format(missing))
def restore(self, filename, model_and_loss, include_params="*", exclude_params=()):
# -----------------------------------------------------------------------------------------
# Make sure file exists
# -----------------------------------------------------------------------------------------
if not os.path.isfile(filename):
logging.info("Could not find checkpoint file '%s'!" % filename)
quit()
# -----------------------------------------------------------------------------------------
# Load checkpoint from file including the state_dict
# -----------------------------------------------------------------------------------------
checkpoint_with_state = torch.load(filename)
# -----------------------------------------------------------------------------------------
# Load filtered state dictionary
# -----------------------------------------------------------------------------------------
state_dict = checkpoint_with_state[self._model_key]
restore_keys = tools.filter_list_of_strings(
state_dict.keys(),
include=include_params,
exclude=exclude_params)
state_dict = {key: value for key, value in state_dict.items() if key in restore_keys}
self._load_state_dict_into_module(state_dict, model_and_loss)
# logging.info(" Restore keys:")
# for key in restore_keys:
# logging.info(" %s" % key)
# -----------------------------------------------------------------------------------------
# Get checkpoint statistics without the state dict
# -----------------------------------------------------------------------------------------
checkpoint_stats = {
key: value for key, value in checkpoint_with_state.items() if key != self._model_key
}
return checkpoint_stats, filename
def restore_latest(self, directory, model_and_loss, include_params="*", exclude_params=()):
latest_checkpoint_filename = os.path.join(
directory, self._prefix + self._latest_postfix + self._extension)
return self.restore(latest_checkpoint_filename, model_and_loss, include_params, exclude_params)
def restore_best(self, directory, model_and_loss, include_params="*", exclude_params=()):
best_checkpoint_filename = os.path.join(
directory, self._prefix + self._best_postfix + self._extension)
return self.restore(best_checkpoint_filename, model_and_loss, include_params, exclude_params)
def save_latest(self, directory, model_and_loss, stats_dict, store_as_best=False):
# -----------------------------------------------------------------------------------------
# Make sure directory exists
# -----------------------------------------------------------------------------------------
tools.ensure_dir(directory)
# -----------------------------------------------------------------------------------------
# Save
# -----------------------------------------------------------------------------------------
save_dict = dict(stats_dict)
save_dict[self._model_key] = model_and_loss.state_dict()
latest_checkpoint_filename = os.path.join(
directory, self._prefix + self._latest_postfix + self._extension)
latest_statistics_filename = os.path.join(
directory, self._prefix + self._latest_postfix + ".json")
torch.save(save_dict, latest_checkpoint_filename)
tools.write_json(data_dict=stats_dict, filename=latest_statistics_filename)
# -----------------------------------------------------------------------------------------
# Possibly store as best
# -----------------------------------------------------------------------------------------
if store_as_best:
best_checkpoint_filename = os.path.join(
directory, self._prefix + self._best_postfix + self._extension)
best_statistics_filename = os.path.join(
directory, self._prefix + self._best_postfix + ".json")
logging.info("Saved checkpoint as best model.")
shutil.copyfile(latest_checkpoint_filename, best_checkpoint_filename)
shutil.copyfile(latest_statistics_filename, best_statistics_filename)
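`save_latest` and the `restore_*` methods compose checkpoint filenames as prefix + postfix + extension, matching the files under `saved_check_point/`. A sketch of that convention (the directory is illustrative):

```python
import os

# Build the latest/best checkpoint and statistics paths the way
# CheckpointSaver composes them from its defaults.
def checkpoint_paths(directory, prefix="checkpoint", extension=".ckpt"):
    return {
        "latest": os.path.join(directory, prefix + "_latest" + extension),
        "best": os.path.join(directory, prefix + "_best" + extension),
        "latest_stats": os.path.join(directory, prefix + "_latest.json"),
        "best_stats": os.path.join(directory, prefix + "_best.json"),
    }

paths = checkpoint_paths("saved_check_point/pwcnet/IRR-PWC_sintel")
print(paths["latest"])
```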
def configure_checkpoint_saver(args, model_and_loss):
with logger.LoggingBlock("Checkpoint", emph=True):
checkpoint_saver = CheckpointSaver()
checkpoint_stats = None
if args.checkpoint is None:
logging.info("No checkpoint given.")
logging.info("Starting from scratch with random initialization.")
elif os.path.isfile(args.checkpoint):
checkpoint_stats, filename = checkpoint_saver.restore(
filename=args.checkpoint,
model_and_loss=model_and_loss,
include_params=args.checkpoint_include_params,
exclude_params=args.checkpoint_exclude_params)
elif os.path.isdir(args.checkpoint):
if args.checkpoint_mode in ["resume_from_best"]:
logging.info("Loading best checkpoint in %s" % args.checkpoint)
checkpoint_stats, filename = checkpoint_saver.restore_best(
directory=args.checkpoint,
model_and_loss=model_and_loss,
include_params=args.checkpoint_include_params,
exclude_params=args.checkpoint_exclude_params)
elif args.checkpoint_mode in ["resume_from_latest"]:
logging.info("Loading latest checkpoint in %s" % args.checkpoint)
checkpoint_stats, filename = checkpoint_saver.restore_latest(
directory=args.checkpoint,
model_and_loss=model_and_loss,
include_params=args.checkpoint_include_params,
exclude_params=args.checkpoint_exclude_params)
else:
logging.info("Unknown checkpoint_mode '%s' given!" % args.checkpoint_mode)
quit()
else:
logging.info("Could not find checkpoint file or directory '%s'" % args.checkpoint)
quit()
return checkpoint_saver, checkpoint_stats
# -------------------------------------------------------------------------------------------------
# Configure data loading
# -------------------------------------------------------------------------------------------------
def configure_data_loaders(args):
with logger.LoggingBlock("Datasets", emph=True):
def _sizes_to_str(value):
if np.isscalar(value):
return '[1L]'
else:
return str([d for d in value.size()])
def _log_statistics(dataset, prefix, name):
with logger.LoggingBlock("%s Dataset: %s" % (prefix, name)):
example_dict = dataset[0] # get sizes from first dataset example
for key, value in sorted(example_dict.items()):
if key in ["index", "basename"]: # no need to display these
continue
if isinstance(value, str):
logging.info("{}: {}".format(key, value))
else:
logging.info("%s: %s" % (key, _sizes_to_str(value)))
logging.info("num_examples: %i" % len(dataset))
# -----------------------------------------------------------------------------------------
# GPU parameters -- pin_memory is turned off to work around a DataLoader deadlock
# -----------------------------------------------------------------------------------------
gpuargs = {"num_workers": args.num_workers, "pin_memory": False} if args.cuda else {}
train_loader = None
validation_loader = None
inference_loader = None
# -----------------------------------------------------------------------------------------
# Training dataset
# -----------------------------------------------------------------------------------------
if args.training_dataset is not None:
# ----------------------------------------------
# Figure out training_dataset arguments
# ----------------------------------------------
kwargs = tools.kwargs_from_args(args, "training_dataset")
kwargs["is_cropped"] = True
kwargs["args"] = args
# ----------------------------------------------
# Create training dataset
# ----------------------------------------------
train_dataset = tools.instance_from_kwargs(args.training_dataset_class, kwargs)
# ----------------------------------------------
# Create training loader
# ----------------------------------------------
train_loader = DataLoader(
train_dataset,
batch_size=args.batch_size,
shuffle=True,
drop_last=False,
**gpuargs)
_log_statistics(train_dataset, prefix="Training", name=args.training_dataset)
# -----------------------------------------------------------------------------------------
# Validation dataset
# -----------------------------------------------------------------------------------------
if args.validation_dataset is not None:
# ----------------------------------------------
# Figure out validation_dataset arguments
# ----------------------------------------------
kwargs = tools.kwargs_from_args(args, "validation_dataset")
kwargs["is_cropped"] = True
kwargs["args"] = args
# ----------------------------------------------
# Create validation dataset
# ----------------------------------------------
validation_dataset = tools.instance_from_kwargs(args.validation_dataset_class, kwargs)
# ----------------------------------------------
# Create validation loader
# ----------------------------------------------
validation_loader = DataLoader(
validation_dataset,
batch_size=args.batch_size_val,
shuffle=False,
drop_last=False,
**gpuargs)
_log_statistics(validation_dataset, prefix="Validation", name=args.validation_dataset)
return train_loader, validation_loader, inference_loader
# ------------------------------------------------------------
# Generator for trainable parameters by pattern matching
# ------------------------------------------------------------
def _print_trainable_params(model_and_loss, match="*"):
total = 0
for name, p in model_and_loss.named_parameters():
if fnmatch.fnmatch(name, match) and p.requires_grad:
logging.info(name)
logging.info(str(p.numel()))
total += p.numel()
logging.info(str(total))
def _generate_trainable_params(model_and_loss, match="*"):
for name, p in model_and_loss.named_parameters():
if fnmatch.fnmatch(name, match):
if p.requires_grad:
yield p
def _param_names_and_trainable_generator(model_and_loss, match="*"):
names = []
for name, p in model_and_loss.named_parameters():
if fnmatch.fnmatch(name, match):
if p.requires_grad:
names.append(name)
return names, _generate_trainable_params(model_and_loss, match=match)
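Parameter selection above relies on `fnmatch`-style shell patterns applied to the fully qualified names from `named_parameters()`. A small demo with hypothetical parameter names:

```python
import fnmatch

# Hypothetical parameter names in the dotted style of named_parameters().
names = [
    "_model.feature_pyramid_extractor.convs.0.weight",
    "_model.flow_estimators.conv_last.bias",
    "_training_loss.weight",
]

def match(names, pattern="*"):
    # "*" matches across dots, so the default pattern selects everything
    return [n for n in names if fnmatch.fnmatch(n, pattern)]

print(match(names, "_model.flow_estimators*"))  # ['_model.flow_estimators.conv_last.bias']
```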
# -------------------------------------------------------------------------------------------------
# Build optimizer:
# -------------------------------------------------------------------------------------------------
def configure_optimizer(args, model_and_loss):
optimizer = None
with logger.LoggingBlock("Optimizer", emph=True):
if args.optimizer is not None:
if model_and_loss.num_parameters() == 0:
logging.info("No trainable parameters detected.")
logging.info("Setting optimizer to None.")
else:
logging.info(args.optimizer)
# -------------------------------------------
# Figure out all optimizer arguments
# -------------------------------------------
all_kwargs = tools.kwargs_from_args(args, "optimizer")
# -------------------------------------------
# Get the split of param groups
# -------------------------------------------
kwargs_without_groups = {
key: value for key,value in all_kwargs.items() if key != "group"
}
param_groups = all_kwargs["group"]
# ----------------------------------------------------------------------
# Print arguments (without groups)
# ----------------------------------------------------------------------
for param, default in sorted(kwargs_without_groups.items()):
logging.info("%s: %s" % (param, default))
# ----------------------------------------------------------------------
# Construct actual optimizer params
# ----------------------------------------------------------------------
kwargs = dict(kwargs_without_groups)
if param_groups is None:
# ---------------------------------------------------------
# Add all trainable parameters if there is no param groups
# ---------------------------------------------------------
all_trainable_parameters = _generate_trainable_params(model_and_loss)
kwargs["params"] = all_trainable_parameters
else:
# -------------------------------------------
# Add list of parameter groups instead
# -------------------------------------------
trainable_parameter_groups = []
dnames, dparams = _param_names_and_trainable_generator(model_and_loss)
dnames = set(dnames)
dparams = set(list(dparams))
with logger.LoggingBlock("parameter_groups:"):
for group in param_groups:
# log group settings
group_match = group["params"]
group_args = {
key: value for key, value in group.items() if key != "params"
}
with logger.LoggingBlock("%s: %s" % (group_match, group_args)):
# retrieve parameters by matching name
gnames, gparams = _param_names_and_trainable_generator(
model_and_loss, match=group_match)
# log all names affected
for n in sorted(gnames):
logging.info(n)
# set generator for group
group_args["params"] = gparams
# append parameter group
trainable_parameter_groups.append(group_args)
# update remaining trainable parameters
dnames -= set(gnames)
dparams -= set(list(gparams))
# append default parameter group
trainable_parameter_groups.append({"params": list(dparams)})
# and log its parameter names
with logger.LoggingBlock("default:"):
for dname in sorted(dnames):
logging.info(dname)
# set params in optimizer kwargs
kwargs["params"] = trainable_parameter_groups
# -------------------------------------------
# Create optimizer instance
# -------------------------------------------
optimizer = tools.instance_from_kwargs(args.optimizer_class, kwargs)
return optimizer
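The group logic above assigns each pattern-matched parameter to its group's settings and collects the leftovers into a default group. A torch-free sketch, under the simplifying assumption that parameters are plain values keyed by name:

```python
import fnmatch

# Split named parameters into optimizer groups: names matching a group's
# pattern take that group's settings; unmatched names form a default group.
def split_groups(named_params, groups):
    remaining = dict(named_params)
    out = []
    for pattern, opts in groups:
        picked = [n for n in remaining if fnmatch.fnmatch(n, pattern)]
        out.append({"params": [remaining.pop(n) for n in picked], **opts})
    out.append({"params": list(remaining.values())})  # default group
    return out

named_params = {"encoder.w": 1, "decoder.w": 2, "head.b": 3}  # hypothetical
groups = split_groups(named_params, [("encoder*", {"lr": 1e-5})])
print(groups)
```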
# -------------------------------------------------------------------------------------------------
# Configure learning rate scheduler
# -------------------------------------------------------------------------------------------------
def configure_lr_scheduler(args, optimizer):
lr_scheduler = None
with logger.LoggingBlock("Learning Rate Scheduler", emph=True):
logging.info("class: %s" % args.lr_scheduler)
if args.lr_scheduler is not None:
# ----------------------------------------------
# Figure out lr_scheduler arguments
# ----------------------------------------------
kwargs = tools.kwargs_from_args(args, "lr_scheduler")
# -------------------------------------------
# Print arguments
# -------------------------------------------
for param, default in sorted(kwargs.items()):
logging.info("%s: %s" % (param, default))
# -------------------------------------------
# Add optimizer
# -------------------------------------------
kwargs["optimizer"] = optimizer
# -------------------------------------------
# Create lr_scheduler instance
# -------------------------------------------
lr_scheduler = tools.instance_from_kwargs(args.lr_scheduler_class, kwargs)
return lr_scheduler
================================================
FILE: datasets/__init__.py
================================================
from . import flyingchairs
from . import flyingchairsOcc
from . import sintel
from . import flyingThings3D
from . import kitti_combined
## FlyingChairs
FlyingChairsTrain = flyingchairs.FlyingChairsTrain
FlyingChairsValid = flyingchairs.FlyingChairsValid
FlyingChairsFull = flyingchairs.FlyingChairsFull
## Our custom FlyingChairs + Occ
FlyingChairsOccTrain = flyingchairsOcc.FlyingChairsOccTrain
FlyingChairsOccValid = flyingchairsOcc.FlyingChairsOccValid
FlyingChairsOccFull = flyingchairsOcc.FlyingChairsOccFull
## FlyingThings3D_subset
FlyingThings3dFinalTrain = flyingThings3D.FlyingThings3dFinalTrain
FlyingThings3dFinalTest = flyingThings3D.FlyingThings3dFinalTest
FlyingThings3dCleanTrain = flyingThings3D.FlyingThings3dCleanTrain
FlyingThings3dCleanTest = flyingThings3D.FlyingThings3dCleanTest
## Sintel
SintelTestClean = sintel.SintelTestClean
SintelTestFinal = sintel.SintelTestFinal
SintelTrainingCombFull = sintel.SintelTrainingCombFull
SintelTrainingCombTrain = sintel.SintelTrainingCombTrain
SintelTrainingCombValid = sintel.SintelTrainingCombValid
SintelTrainingCleanFull = sintel.SintelTrainingCleanFull
SintelTrainingCleanTrain = sintel.SintelTrainingCleanTrain
SintelTrainingCleanValid = sintel.SintelTrainingCleanValid
SintelTrainingFinalFull = sintel.SintelTrainingFinalFull
SintelTrainingFinalTrain = sintel.SintelTrainingFinalTrain
SintelTrainingFinalValid = sintel.SintelTrainingFinalValid
## KITTI Optical Flow 2012 + 2015
KittiCombTrain = kitti_combined.KittiCombTrain
KittiCombVal = kitti_combined.KittiCombVal
KittiCombFull = kitti_combined.KittiCombFull
KittiComb2012Train = kitti_combined.KittiComb2012Train
KittiComb2012Val = kitti_combined.KittiComb2012Val
KittiComb2012Full = kitti_combined.KittiComb2012Full
KittiComb2012Test = kitti_combined.KittiComb2012Test
KittiComb2015Train = kitti_combined.KittiComb2015Train
KittiComb2015Val = kitti_combined.KittiComb2015Val
KittiComb2015Full = kitti_combined.KittiComb2015Full
KittiComb2015Test = kitti_combined.KittiComb2015Test
================================================
FILE: datasets/common.py
================================================
## Portions of code copyright 2018 Jochen Gast
from __future__ import absolute_import, division, print_function
import torch
import numpy as np
import skimage.io as io
def numpy2torch(array):
assert(isinstance(array, np.ndarray))
if array.ndim == 3:
array = np.transpose(array, (2, 0, 1))
else:
array = np.expand_dims(array, axis=0)
return torch.from_numpy(array.copy()).float()
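`numpy2torch` converts images from the H×W×C layout produced by image readers into the C×H×W layout PyTorch expects, adding a leading channel axis for single-channel maps. A sketch of just the NumPy part (the final `torch.from_numpy` step is omitted here):

```python
import numpy as np

def hwc_to_chw(array):
    # H x W x C -> C x H x W; single-channel H x W -> 1 x H x W
    if array.ndim == 3:
        return np.transpose(array, (2, 0, 1))
    return np.expand_dims(array, axis=0)

img = np.zeros((4, 6, 3), dtype=np.float32)  # H=4, W=6, C=3
occ = np.zeros((4, 6), dtype=np.float32)     # occlusion map, one channel
print(hwc_to_chw(img).shape, hwc_to_chw(occ).shape)  # (3, 4, 6) (1, 4, 6)
```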
def read_flo_as_float32(filename):
with open(filename, 'rb') as file:
magic = np.fromfile(file, np.float32, count=1)
assert(202021.25 == magic), "Magic number incorrect. Invalid .flo file"
w = np.fromfile(file, np.int32, count=1)[0]
h = np.fromfile(file, np.int32, count=1)[0]
data = np.fromfile(file, np.float32, count=2*h*w)
data2D = np.reshape(data, (h, w, 2))
return data2D
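`read_flo_as_float32` parses the Middlebury `.flo` layout: a `float32` magic number 202021.25, width and height as `int32`, then `w*h*2` interleaved `float32` values in H×W×2 order. A round-trip sketch against an in-memory buffer (the writer is a hypothetical helper, not part of this repo):

```python
import io
import numpy as np

def write_flo(buf, flow):
    # Middlebury .flo: magic, width, height, then interleaved (u, v) floats
    h, w, _ = flow.shape
    buf.write(np.float32(202021.25).tobytes())
    buf.write(np.int32(w).tobytes())
    buf.write(np.int32(h).tobytes())
    buf.write(flow.astype(np.float32).tobytes())

flow = np.arange(2 * 3 * 2, dtype=np.float32).reshape(2, 3, 2)
buf = io.BytesIO()
write_flo(buf, flow)
buf.seek(0)
magic = np.frombuffer(buf.read(4), np.float32)[0]
w = np.frombuffer(buf.read(4), np.int32)[0]
h = np.frombuffer(buf.read(4), np.int32)[0]
data = np.frombuffer(buf.read(), np.float32).reshape(h, w, 2)
print(w, h)  # 3 2
```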
def read_occ_image_as_float32(filename):
occ = io.imread(filename).astype(np.float32) / np.float32(255.0)
if occ.ndim == 3:
occ = occ[:, :, 0]
return occ
def read_image_as_float32(filename):
return io.imread(filename).astype(np.float32) / np.float32(255.0)
def read_image_as_byte(filename):
return io.imread(filename)
================================================
FILE: datasets/flyingThings3D.py
================================================
from __future__ import absolute_import, division, print_function
import os
import torch.utils.data as data
from glob import glob
from torchvision import transforms as vision_transforms
from . import transforms
from . import common
import numpy as np
def fillingInNaN(flow):
h, w, c = flow.shape
indices = np.argwhere(np.isnan(flow))
neighbors = [[-1, 0], [1, 0], [0, -1], [0, 1]]
for idx in indices:
sum_sample = 0
count = 0
for jj in range(len(neighbors)):
hh = idx[0] + neighbors[jj][0]
ww = idx[1] + neighbors[jj][1]
if hh < 0 or hh >= h:
continue
if ww < 0 or ww >= w:
continue
sample_flow = flow[hh, ww, idx[2]]
if np.isnan(sample_flow):
continue
sum_sample += sample_flow
count += 1
if count == 0:
print('FATAL ERROR: no valid neighboring sample found')
continue
flow[idx[0], idx[1], idx[2]] = sum_sample / count
return flow
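`fillingInNaN` repairs NaN entries by averaging the valid 4-neighbours (up, down, left, right) of each NaN position, per channel. A sketch of that repair on a single channel (the helper name is hypothetical):

```python
import numpy as np

def fill_nan_4neighbour(channel):
    # Replace each NaN with the mean of its valid up/down/left/right neighbours.
    h, w = channel.shape
    out = channel.copy()
    for y, x in np.argwhere(np.isnan(channel)):
        samples = [channel[yy, xx]
                   for yy, xx in ((y - 1, x), (y + 1, x), (y, x - 1), (y, x + 1))
                   if 0 <= yy < h and 0 <= xx < w and not np.isnan(channel[yy, xx])]
        if samples:  # otherwise leave the NaN in place
            out[y, x] = sum(samples) / len(samples)
    return out

c = np.array([[1.0, np.nan, 3.0],
              [4.0, 5.0, 6.0]], dtype=np.float32)
filled = fill_nan_4neighbour(c)
print(filled[0, 1])  # 3.0 -> mean of neighbours 1.0, 5.0, 3.0
```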
class FlyingThings3d(data.Dataset):
def __init__(self,
args,
images_root,
flow_root,
occ_root,
photometric_augmentations=False):
self._args = args
if not os.path.isdir(images_root):
raise ValueError("Image directory '%s' not found!" % images_root)
if flow_root is not None and not os.path.isdir(flow_root):
raise ValueError("Flow directory '%s' not found!" % flow_root)
if occ_root is not None and not os.path.isdir(occ_root):
raise ValueError("Occ directory '%s' not found!" % occ_root)
if flow_root is not None:
flow_f_filenames = sorted(glob(os.path.join(flow_root, "into_future/*.flo")))
flow_b_filenames = sorted(glob(os.path.join(flow_root, "into_past/*.flo")))
if occ_root is not None:
occ1_filenames = sorted(glob(os.path.join(occ_root, "into_future/*.png")))
occ2_filenames = sorted(glob(os.path.join(occ_root, "into_past/*.png")))
all_img_filenames = sorted(glob(os.path.join(images_root, "*.png")))
self._image_list = []
self._flow_list = [] if flow_root is not None else None
self._occ_list = [] if occ_root is not None else None
assert len(all_img_filenames) != 0
assert len(flow_f_filenames) != 0
assert len(flow_b_filenames) != 0
assert len(occ1_filenames) != 0
assert len(occ2_filenames) != 0
## path definition
path_flow_f = os.path.join(flow_root, "into_future")
path_flow_b = os.path.join(flow_root, "into_past")
path_occ_f = os.path.join(occ_root, "into_future")
path_occ_b = os.path.join(occ_root, "into_past")
# ----------------------------------------------------------
# Save list of actual filenames for inputs and flows
# ----------------------------------------------------------
for ii in range(0, len(flow_f_filenames)):
flo_f = flow_f_filenames[ii]
idx_f = os.path.splitext(os.path.basename(flo_f))[0]
idx_b = str(int(idx_f) + 1).zfill(7)
flo_b = os.path.join(path_flow_b, idx_b + ".flo")
im1 = os.path.join(images_root, idx_f + ".png")
im2 = os.path.join(images_root, idx_b + ".png")
occ1 = os.path.join(path_occ_f, idx_f + ".png")
occ2 = os.path.join(path_occ_b, idx_b + ".png")
if not os.path.isfile(flo_f) or not os.path.isfile(flo_b) or not os.path.isfile(im1) or not os.path.isfile(
im2) or not os.path.isfile(occ1) or not os.path.isfile(occ2):
continue
self._image_list += [[im1, im2]]
self._flow_list += [[flo_f, flo_b]]
self._occ_list += [[occ1, occ2]]
self._size = len(self._image_list)
assert len(self._image_list) == len(self._flow_list)
assert len(self._occ_list) == len(self._flow_list)
assert len(self._image_list) != 0
# ----------------------------------------------------------
# photometric_augmentations
# ----------------------------------------------------------
if photometric_augmentations:
self._photometric_transform = transforms.ConcatTransformSplitChainer([
# uint8 -> PIL
vision_transforms.ToPILImage(),
# PIL -> PIL : random hsv and contrast
vision_transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
# PIL -> FloatTensor
vision_transforms.transforms.ToTensor(),
transforms.RandomGamma(min_gamma=0.7, max_gamma=1.5, clip_image=True),
], from_numpy=True, to_numpy=False)
else:
self._photometric_transform = transforms.ConcatTransformSplitChainer([
# uint8 -> FloatTensor
vision_transforms.transforms.ToTensor(),
], from_numpy=True, to_numpy=False)
def __getitem__(self, index):
index = index % self._size
im1_filename = self._image_list[index][0]
im2_filename = self._image_list[index][1]
flo_f_filename = self._flow_list[index][0]
flo_b_filename = self._flow_list[index][1]
occ1_filename = self._occ_list[index][0]
occ2_filename = self._occ_list[index][1]
# read float32 images and flow
im1_np0 = common.read_image_as_byte(im1_filename)
im2_np0 = common.read_image_as_byte(im2_filename)
flo_f_np0 = common.read_flo_as_float32(flo_f_filename)
flo_b_np0 = common.read_flo_as_float32(flo_b_filename)
occ1_np0 = common.read_occ_image_as_float32(occ1_filename)
occ2_np0 = common.read_occ_image_as_float32(occ2_filename)
# temp - check isnan
if np.any(np.isnan(flo_f_np0)):
flo_f_np0 = fillingInNaN(flo_f_np0)
if np.any(np.isnan(flo_b_np0)):
flo_b_np0 = fillingInNaN(flo_b_np0)
# possibly apply photometric transformations
im1, im2 = self._photometric_transform(im1_np0, im2_np0)
# convert flow to FloatTensor
flo_f = common.numpy2torch(flo_f_np0)
flo_b = common.numpy2torch(flo_b_np0)
# convert occ to FloatTensor
occ1 = common.numpy2torch(occ1_np0)
occ2 = common.numpy2torch(occ2_np0)
# example filename
basename = os.path.basename(im1_filename)[:5]
example_dict = {
"input1": im1,
"input2": im2,
"target1": flo_f,
"target2": flo_b,
"target_occ1": occ1,
"target_occ2": occ2,
"index": index,
"basename": basename
}
return example_dict
def __len__(self):
return self._size
class FlyingThings3dFinalTrain(FlyingThings3d):
def __init__(self,
args,
root,
photometric_augmentations=True):
images_root = os.path.join(root, "frames_finalpass")
flow_root = os.path.join(root, "optical_flow")
occ_root = os.path.join(root, "occlusion")
super(FlyingThings3dFinalTrain, self).__init__(
args,
images_root=images_root,
flow_root=flow_root,
occ_root=occ_root,
photometric_augmentations=photometric_augmentations)
class FlyingThings3dFinalTest(FlyingThings3d):
def __init__(self,
args,
root,
photometric_augmentations=False):
images_root = os.path.join(root, "frames_finalpass")
flow_root = os.path.join(root, "optical_flow")
occ_root = os.path.join(root, "occlusion")
super(FlyingThings3dFinalTest, self).__init__(
args,
images_root=images_root,
flow_root=flow_root,
occ_root=occ_root,
photometric_augmentations=photometric_augmentations)
class FlyingThings3dCleanTrain(FlyingThings3d):
def __init__(self,
args,
root,
photometric_augmentations=True):
images_root = os.path.join(root, "train", "image_clean", "left")
flow_root = os.path.join(root, "train", "flow", "left")
occ_root = os.path.join(root, "train", "flow_occlusions", "left")
super(FlyingThings3dCleanTrain, self).__init__(
args,
images_root=images_root,
flow_root=flow_root,
occ_root=occ_root,
photometric_augmentations=photometric_augmentations)
class FlyingThings3dCleanTest(FlyingThings3d):
def __init__(self,
args,
root,
photometric_augmentations=False):
images_root = os.path.join(root, "frames_cleanpass")
flow_root = os.path.join(root, "optical_flow")
occ_root = os.path.join(root, "occlusion")
super(FlyingThings3dCleanTest, self).__init__(
args,
images_root=images_root,
flow_root=flow_root,
occ_root=occ_root,
photometric_augmentations=photometric_augmentations)
================================================
FILE: datasets/flyingchairs.py
================================================
from __future__ import absolute_import, division, print_function
import os
import torch.utils.data as data
from glob import glob
from torchvision import transforms as vision_transforms
from . import transforms
from . import common
VALIDATE_INDICES = [
5, 17, 42, 45, 58, 62, 96, 111, 117, 120, 121, 131, 132,
152, 160, 248, 263, 264, 291, 293, 295, 299, 316, 320, 336,
337, 343, 358, 399, 401, 429, 438, 468, 476, 494, 509, 528,
531, 572, 581, 583, 588, 593, 681, 688, 696, 714, 767, 786,
810, 825, 836, 841, 883, 917, 937, 942, 970, 974, 980, 1016,
1043, 1064, 1118, 1121, 1133, 1153, 1155, 1158, 1159, 1173,
1187, 1219, 1237, 1238, 1259, 1266, 1278, 1296, 1354, 1378,
1387, 1494, 1508, 1518, 1574, 1601, 1614, 1668, 1673, 1699,
1712, 1714, 1737, 1841, 1872, 1879, 1901, 1921, 1934, 1961,
1967, 1978, 2018, 2030, 2039, 2043, 2061, 2113, 2204, 2216,
2236, 2250, 2274, 2292, 2310, 2342, 2359, 2374, 2382, 2399,
2415, 2419, 2483, 2502, 2504, 2576, 2589, 2590, 2622, 2624,
2636, 2651, 2655, 2658, 2659, 2664, 2672, 2706, 2707, 2709,
2725, 2732, 2761, 2827, 2864, 2866, 2905, 2922, 2929, 2966,
2972, 2993, 3010, 3025, 3031, 3040, 3041, 3070, 3113, 3124,
3129, 3137, 3141, 3157, 3183, 3206, 3219, 3247, 3253, 3272,
3276, 3321, 3328, 3333, 3338, 3341, 3346, 3351, 3396, 3419,
3430, 3433, 3448, 3455, 3463, 3503, 3526, 3529, 3537, 3555,
3577, 3584, 3591, 3594, 3597, 3603, 3613, 3615, 3670, 3676,
3678, 3697, 3723, 3728, 3734, 3745, 3750, 3752, 3779, 3782,
3813, 3817, 3819, 3854, 3885, 3944, 3947, 3970, 3985, 4011,
4022, 4071, 4075, 4132, 4158, 4167, 4190, 4194, 4207, 4246,
4249, 4298, 4307, 4317, 4318, 4319, 4320, 4382, 4399, 4401,
4407, 4416, 4423, 4484, 4491, 4493, 4517, 4525, 4538, 4578,
4606, 4609, 4620, 4623, 4637, 4646, 4662, 4668, 4716, 4739,
4747, 4770, 4774, 4776, 4785, 4800, 4845, 4863, 4891, 4904,
4922, 4925, 4956, 4963, 4964, 4994, 5011, 5019, 5036, 5038,
5041, 5055, 5118, 5122, 5130, 5162, 5164, 5178, 5196, 5227,
5266, 5270, 5273, 5279, 5299, 5310, 5314, 5363, 5375, 5384,
5393, 5414, 5417, 5433, 5448, 5494, 5505, 5509, 5525, 5566,
5581, 5602, 5609, 5620, 5653, 5670, 5678, 5690, 5700, 5703,
5724, 5752, 5765, 5803, 5811, 5860, 5881, 5895, 5912, 5915,
5940, 5952, 5966, 5977, 5988, 6007, 6037, 6061, 6069, 6080,
6111, 6127, 6146, 6161, 6166, 6168, 6178, 6182, 6190, 6220,
6235, 6253, 6270, 6343, 6372, 6379, 6410, 6411, 6442, 6453,
6481, 6498, 6500, 6509, 6532, 6541, 6543, 6560, 6576, 6580,
6594, 6595, 6609, 6625, 6629, 6644, 6658, 6673, 6680, 6698,
6699, 6702, 6705, 6741, 6759, 6785, 6792, 6794, 6809, 6810,
6830, 6838, 6869, 6871, 6889, 6925, 6995, 7003, 7026, 7029,
7080, 7082, 7097, 7102, 7116, 7165, 7200, 7232, 7271, 7282,
7324, 7333, 7335, 7372, 7387, 7407, 7472, 7474, 7482, 7489,
7499, 7516, 7533, 7536, 7566, 7620, 7654, 7691, 7704, 7722,
7746, 7750, 7773, 7806, 7821, 7827, 7851, 7873, 7880, 7884,
7904, 7912, 7948, 7964, 7965, 7984, 7989, 7992, 8035, 8050,
8074, 8091, 8094, 8113, 8116, 8151, 8159, 8171, 8179, 8194,
8195, 8239, 8263, 8290, 8295, 8312, 8367, 8374, 8387, 8407,
8437, 8439, 8518, 8556, 8588, 8597, 8601, 8651, 8657, 8723,
8759, 8763, 8785, 8802, 8813, 8826, 8854, 8856, 8866, 8918,
8922, 8923, 8932, 8958, 8967, 9003, 9018, 9078, 9095, 9104,
9112, 9129, 9147, 9170, 9171, 9197, 9200, 9249, 9253, 9270,
9282, 9288, 9295, 9321, 9323, 9324, 9347, 9399, 9403, 9417,
9426, 9427, 9439, 9468, 9486, 9496, 9511, 9516, 9518, 9529,
9557, 9563, 9564, 9584, 9586, 9591, 9599, 9600, 9601, 9632,
9654, 9667, 9678, 9696, 9716, 9723, 9740, 9820, 9824, 9825,
9828, 9863, 9866, 9868, 9889, 9929, 9938, 9953, 9967, 10019,
10020, 10025, 10059, 10111, 10118, 10125, 10174, 10194,
10201, 10202, 10220, 10221, 10226, 10242, 10250, 10276,
10295, 10302, 10305, 10327, 10351, 10360, 10369, 10393,
10407, 10438, 10455, 10463, 10465, 10470, 10478, 10503,
10508, 10509, 10809, 11080, 11331, 11607, 11610, 11864,
12390, 12393, 12396, 12399, 12671, 12921, 12930, 13178,
13453, 13717, 14499, 14517, 14775, 15297, 15556, 15834,
15839, 16126, 16127, 16386, 16633, 16644, 16651, 17166,
17169, 17958, 17959, 17962, 18224, 21176, 21180, 21190,
21802, 21803, 21806, 22584, 22857, 22858, 22866]
class FlyingChairs(data.Dataset):
def __init__(self,
args,
root,
photometric_augmentations=False,
dstype="train"):
self._args = args
# -------------------------------------------------------------
# filenames for all input images and target flows
# -------------------------------------------------------------
image_filenames = sorted( glob( os.path.join(root, "*.ppm")) )
flow_filenames = sorted( glob( os.path.join(root, "*.flo")) )
        assert len(image_filenames) == 2 * len(flow_filenames)
num_flows = len(flow_filenames)
# -------------------------------------------------------------
# Remove invalid validation indices
# -------------------------------------------------------------
validate_indices = [x for x in VALIDATE_INDICES if x in range(num_flows)]
# ----------------------------------------------------------
# Construct list of indices for training/validation
# ----------------------------------------------------------
list_of_indices = None
if dstype == "train":
list_of_indices = [x for x in range(num_flows) if x not in validate_indices]
elif dstype == "valid":
list_of_indices = validate_indices
elif dstype == "full":
list_of_indices = range(num_flows)
else:
            raise ValueError("FlyingChairs: dstype '%s' unknown!" % dstype)
# ----------------------------------------------------------
# Save list of actual filenames for inputs and flows
# ----------------------------------------------------------
self._image_list = []
self._flow_list = []
for i in list_of_indices:
flo = flow_filenames[i]
im1 = image_filenames[2*i]
im2 = image_filenames[2*i + 1]
self._image_list += [ [ im1, im2 ] ]
self._flow_list += [ flo ]
self._size = len(self._image_list)
assert len(self._image_list) == len(self._flow_list)
# ----------------------------------------------------------
# photometric_augmentations
# ----------------------------------------------------------
if photometric_augmentations:
self._photometric_transform = transforms.ConcatTransformSplitChainer([
# uint8 -> PIL
vision_transforms.ToPILImage(),
# PIL -> PIL : random hsv and contrast
vision_transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
# PIL -> FloatTensor
vision_transforms.transforms.ToTensor(),
transforms.RandomGamma(min_gamma=0.7, max_gamma=1.5, clip_image=True),
], from_numpy=True, to_numpy=False)
else:
self._photometric_transform = transforms.ConcatTransformSplitChainer([
# uint8 -> FloatTensor
vision_transforms.transforms.ToTensor(),
], from_numpy=True, to_numpy=False)
def __getitem__(self, index):
index = index % self._size
im1_filename = self._image_list[index][0]
im2_filename = self._image_list[index][1]
flo_filename = self._flow_list[index]
        # read uint8 images and float32 flow
im1_np0 = common.read_image_as_byte(im1_filename)
im2_np0 = common.read_image_as_byte(im2_filename)
flo_np0 = common.read_flo_as_float32(flo_filename)
# possibly apply photometric transformations
im1, im2 = self._photometric_transform(im1_np0, im2_np0)
# convert flow to FloatTensor
flo = common.numpy2torch(flo_np0)
# target_occ: initialized by zero (not used)
target_occ = common.numpy2torch(common.read_occ_image_as_float32(im1_filename)) * 0
# example filename
basename = os.path.basename(im1_filename)[:5]
example_dict = {
"input1": im1,
"input2": im2,
"target1": flo,
"target_occ1": target_occ,
"index": index,
"basename": basename
}
return example_dict
def __len__(self):
return self._size
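The `dstype`-based train/valid/full split in `FlyingChairs.__init__` can be sketched standalone. This is an illustration only: `VALIDATE_INDICES` below is a small hypothetical subset, not the real list above.

```python
# Standalone sketch of the dstype-based index split used in
# FlyingChairs.__init__; VALIDATE_INDICES here is a tiny made-up subset.
VALIDATE_INDICES = [1, 4]
num_flows = 6

# keep only validation indices that actually exist in this dataset
validate_indices = [x for x in VALIDATE_INDICES if x in range(num_flows)]

def split_indices(dstype):
    # mirrors the "train" / "valid" / "full" branches above
    if dstype == "train":
        return [x for x in range(num_flows) if x not in validate_indices]
    elif dstype == "valid":
        return validate_indices
    elif dstype == "full":
        return list(range(num_flows))
    raise ValueError("dstype '%s' unknown!" % dstype)
```

With six samples and validation indices {1, 4}, the train split is the complement [0, 2, 3, 5].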
class FlyingChairsTrain(FlyingChairs):
def __init__(self,
args,
root,
photometric_augmentations=True):
super(FlyingChairsTrain, self).__init__(
args,
root=root,
photometric_augmentations=photometric_augmentations,
dstype="train")
class FlyingChairsValid(FlyingChairs):
def __init__(self,
args,
root,
photometric_augmentations=False):
super(FlyingChairsValid, self).__init__(
args,
root=root,
photometric_augmentations=photometric_augmentations,
dstype="valid")
class FlyingChairsFull(FlyingChairs):
def __init__(self,
args,
root,
photometric_augmentations=False):
super(FlyingChairsFull, self).__init__(
args,
root=root,
photometric_augmentations=photometric_augmentations,
dstype="full")
================================================
FILE: datasets/flyingchairsOcc.py
================================================
from __future__ import absolute_import, division, print_function
import os
import torch.utils.data as data
from glob import glob
from torchvision import transforms as vision_transforms
from . import transforms
from . import common
VALIDATE_INDICES = [
5, 17, 42, 45, 58, 62, 96, 111, 117, 120, 121, 131, 132,
152, 160, 248, 263, 264, 291, 293, 295, 299, 316, 320, 336,
337, 343, 358, 399, 401, 429, 438, 468, 476, 494, 509, 528,
531, 572, 581, 583, 588, 593, 681, 688, 696, 714, 767, 786,
810, 825, 836, 841, 883, 917, 937, 942, 970, 974, 980, 1016,
1043, 1064, 1118, 1121, 1133, 1153, 1155, 1158, 1159, 1173,
1187, 1219, 1237, 1238, 1259, 1266, 1278, 1296, 1354, 1378,
1387, 1494, 1508, 1518, 1574, 1601, 1614, 1668, 1673, 1699,
1712, 1714, 1737, 1841, 1872, 1879, 1901, 1921, 1934, 1961,
1967, 1978, 2018, 2030, 2039, 2043, 2061, 2113, 2204, 2216,
2236, 2250, 2274, 2292, 2310, 2342, 2359, 2374, 2382, 2399,
2415, 2419, 2483, 2502, 2504, 2576, 2589, 2590, 2622, 2624,
2636, 2651, 2655, 2658, 2659, 2664, 2672, 2706, 2707, 2709,
2725, 2732, 2761, 2827, 2864, 2866, 2905, 2922, 2929, 2966,
2972, 2993, 3010, 3025, 3031, 3040, 3041, 3070, 3113, 3124,
3129, 3137, 3141, 3157, 3183, 3206, 3219, 3247, 3253, 3272,
3276, 3321, 3328, 3333, 3338, 3341, 3346, 3351, 3396, 3419,
3430, 3433, 3448, 3455, 3463, 3503, 3526, 3529, 3537, 3555,
3577, 3584, 3591, 3594, 3597, 3603, 3613, 3615, 3670, 3676,
3678, 3697, 3723, 3728, 3734, 3745, 3750, 3752, 3779, 3782,
3813, 3817, 3819, 3854, 3885, 3944, 3947, 3970, 3985, 4011,
4022, 4071, 4075, 4132, 4158, 4167, 4190, 4194, 4207, 4246,
4249, 4298, 4307, 4317, 4318, 4319, 4320, 4382, 4399, 4401,
4407, 4416, 4423, 4484, 4491, 4493, 4517, 4525, 4538, 4578,
4606, 4609, 4620, 4623, 4637, 4646, 4662, 4668, 4716, 4739,
4747, 4770, 4774, 4776, 4785, 4800, 4845, 4863, 4891, 4904,
4922, 4925, 4956, 4963, 4964, 4994, 5011, 5019, 5036, 5038,
5041, 5055, 5118, 5122, 5130, 5162, 5164, 5178, 5196, 5227,
5266, 5270, 5273, 5279, 5299, 5310, 5314, 5363, 5375, 5384,
5393, 5414, 5417, 5433, 5448, 5494, 5505, 5509, 5525, 5566,
5581, 5602, 5609, 5620, 5653, 5670, 5678, 5690, 5700, 5703,
5724, 5752, 5765, 5803, 5811, 5860, 5881, 5895, 5912, 5915,
5940, 5952, 5966, 5977, 5988, 6007, 6037, 6061, 6069, 6080,
6111, 6127, 6146, 6161, 6166, 6168, 6178, 6182, 6190, 6220,
6235, 6253, 6270, 6343, 6372, 6379, 6410, 6411, 6442, 6453,
6481, 6498, 6500, 6509, 6532, 6541, 6543, 6560, 6576, 6580,
6594, 6595, 6609, 6625, 6629, 6644, 6658, 6673, 6680, 6698,
6699, 6702, 6705, 6741, 6759, 6785, 6792, 6794, 6809, 6810,
6830, 6838, 6869, 6871, 6889, 6925, 6995, 7003, 7026, 7029,
7080, 7082, 7097, 7102, 7116, 7165, 7200, 7232, 7271, 7282,
7324, 7333, 7335, 7372, 7387, 7407, 7472, 7474, 7482, 7489,
7499, 7516, 7533, 7536, 7566, 7620, 7654, 7691, 7704, 7722,
7746, 7750, 7773, 7806, 7821, 7827, 7851, 7873, 7880, 7884,
7904, 7912, 7948, 7964, 7965, 7984, 7989, 7992, 8035, 8050,
8074, 8091, 8094, 8113, 8116, 8151, 8159, 8171, 8179, 8194,
8195, 8239, 8263, 8290, 8295, 8312, 8367, 8374, 8387, 8407,
8437, 8439, 8518, 8556, 8588, 8597, 8601, 8651, 8657, 8723,
8759, 8763, 8785, 8802, 8813, 8826, 8854, 8856, 8866, 8918,
8922, 8923, 8932, 8958, 8967, 9003, 9018, 9078, 9095, 9104,
9112, 9129, 9147, 9170, 9171, 9197, 9200, 9249, 9253, 9270,
9282, 9288, 9295, 9321, 9323, 9324, 9347, 9399, 9403, 9417,
9426, 9427, 9439, 9468, 9486, 9496, 9511, 9516, 9518, 9529,
9557, 9563, 9564, 9584, 9586, 9591, 9599, 9600, 9601, 9632,
9654, 9667, 9678, 9696, 9716, 9723, 9740, 9820, 9824, 9825,
9828, 9863, 9866, 9868, 9889, 9929, 9938, 9953, 9967, 10019,
10020, 10025, 10059, 10111, 10118, 10125, 10174, 10194,
10201, 10202, 10220, 10221, 10226, 10242, 10250, 10276,
10295, 10302, 10305, 10327, 10351, 10360, 10369, 10393,
10407, 10438, 10455, 10463, 10465, 10470, 10478, 10503,
10508, 10509, 10809, 11080, 11331, 11607, 11610, 11864,
12390, 12393, 12396, 12399, 12671, 12921, 12930, 13178,
13453, 13717, 14499, 14517, 14775, 15297, 15556, 15834,
15839, 16126, 16127, 16386, 16633, 16644, 16651, 17166,
17169, 17958, 17959, 17962, 18224, 21176, 21180, 21190,
21802, 21803, 21806, 22584, 22857, 22858, 22866]
class FlyingChairsOcc(data.Dataset):
def __init__(self,
args,
root,
photometric_augmentations=False,
dstype="train"):
self._args = args
# -------------------------------------------------------------
# filenames for all input images and target flows
# -------------------------------------------------------------
image1_filenames = sorted(glob(os.path.join(root, "*_img1.png")))
image2_filenames = sorted(glob(os.path.join(root, "*_img2.png")))
occ1_filenames = sorted(glob(os.path.join(root, "*_occ1.png")))
occ2_filenames = sorted(glob(os.path.join(root, "*_occ2.png")))
flow_f_filenames = sorted(glob(os.path.join(root, "*_flow.flo")))
flow_b_filenames = sorted(glob(os.path.join(root, "*_flow_b.flo")))
assert (len(image1_filenames) == len(image2_filenames))
assert (len(image2_filenames) == len(occ1_filenames))
assert (len(occ1_filenames) == len(occ2_filenames))
assert (len(occ2_filenames) == len(flow_f_filenames))
assert (len(flow_f_filenames) == len(flow_b_filenames))
num_flows = len(flow_f_filenames)
# -------------------------------------------------------------
# Remove invalid validation indices
# -------------------------------------------------------------
validate_indices = [x for x in VALIDATE_INDICES if x in range(num_flows)]
# ----------------------------------------------------------
# Construct list of indices for training/validation
# ----------------------------------------------------------
list_of_indices = None
if dstype == "train":
list_of_indices = [x for x in range(num_flows) if x not in validate_indices]
elif dstype == "valid":
list_of_indices = validate_indices
elif dstype == "full":
list_of_indices = range(num_flows)
else:
            raise ValueError("FlyingChairsOcc: dstype '%s' unknown!" % dstype)
# ----------------------------------------------------------
# Save list of actual filenames for inputs and flows
# ----------------------------------------------------------
self._image_list = []
self._flow_list = []
self._occ_list = []
for i in list_of_indices:
flo_f = flow_f_filenames[i]
flo_b = flow_b_filenames[i]
im1 = image1_filenames[i]
im2 = image2_filenames[i]
occ1 = occ1_filenames[i]
occ2 = occ2_filenames[i]
self._image_list += [[im1, im2]]
self._flow_list += [[flo_f, flo_b]]
self._occ_list += [[occ1, occ2]]
self._size = len(self._image_list)
assert len(self._image_list) == len(self._flow_list)
assert len(self._occ_list) == len(self._flow_list)
# ----------------------------------------------------------
# photometric_augmentations
# ----------------------------------------------------------
if photometric_augmentations:
self._photometric_transform = transforms.ConcatTransformSplitChainer([
# uint8 -> PIL
vision_transforms.ToPILImage(),
# PIL -> PIL : random hsv and contrast
vision_transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
# PIL -> FloatTensor
vision_transforms.transforms.ToTensor(),
transforms.RandomGamma(min_gamma=0.7, max_gamma=1.5, clip_image=True),
], from_numpy=True, to_numpy=False)
else:
self._photometric_transform = transforms.ConcatTransformSplitChainer([
# uint8 -> FloatTensor
vision_transforms.transforms.ToTensor(),
], from_numpy=True, to_numpy=False)
def __getitem__(self, index):
index = index % self._size
im1_filename = self._image_list[index][0]
im2_filename = self._image_list[index][1]
flo_f_filename = self._flow_list[index][0]
flo_b_filename = self._flow_list[index][1]
occ1_filename = self._occ_list[index][0]
occ2_filename = self._occ_list[index][1]
        # read uint8 images, float32 flow, and float32 occlusion maps
im1_np0 = common.read_image_as_byte(im1_filename)
im2_np0 = common.read_image_as_byte(im2_filename)
flo_f_np0 = common.read_flo_as_float32(flo_f_filename)
flo_b_np0 = common.read_flo_as_float32(flo_b_filename)
occ1_np0 = common.read_occ_image_as_float32(occ1_filename)
occ2_np0 = common.read_occ_image_as_float32(occ2_filename)
# possibly apply photometric transformations
im1, im2 = self._photometric_transform(im1_np0, im2_np0)
# convert flow to FloatTensor
flo_f = common.numpy2torch(flo_f_np0)
flo_b = common.numpy2torch(flo_b_np0)
# convert occ to FloatTensor
occ1 = common.numpy2torch(occ1_np0)
occ2 = common.numpy2torch(occ2_np0)
# example filename
basename = os.path.basename(im1_filename)[:5]
example_dict = {
"input1": im1,
"input2": im2,
"target1": flo_f,
"target2": flo_b,
"target_occ1": occ1,
"target_occ2": occ2,
"index": index,
"basename": basename
}
return example_dict
def __len__(self):
return self._size
class FlyingChairsOccTrain(FlyingChairsOcc):
def __init__(self,
args,
root,
photometric_augmentations=True):
super(FlyingChairsOccTrain, self).__init__(
args,
root=root,
photometric_augmentations=photometric_augmentations,
dstype="train")
class FlyingChairsOccValid(FlyingChairsOcc):
def __init__(self,
args,
root,
photometric_augmentations=False):
super(FlyingChairsOccValid, self).__init__(
args,
root=root,
photometric_augmentations=photometric_augmentations,
dstype="valid")
class FlyingChairsOccFull(FlyingChairsOcc):
def __init__(self,
args,
root,
photometric_augmentations=False):
super(FlyingChairsOccFull, self).__init__(
args,
root=root,
photometric_augmentations=photometric_augmentations,
dstype="full")
================================================
FILE: datasets/kitti_combined.py
================================================
from __future__ import absolute_import, division, print_function
import os
import torch.utils.data as data
from glob import glob
from torchvision import transforms as vision_transforms
from . import transforms
from . import common
import numpy as np
import png
VALIDATE_INDICES_2015 = [10, 11, 12, 25, 26, 30, 31, 40, 41, 42, 46, 52, 53, 72, 73, 74, 75, 76, 80, 81, 85, 86, 95, 96, 97, 98, 104, 116, 117, 120, 121, 126, 127, 153, 172, 175, 183, 184, 190, 199]
VALIDATE_INDICES_2012 = [0, 12, 15, 16, 17, 18, 24, 30, 38, 39, 42, 50, 54, 59, 60, 61, 77, 78, 81, 89, 97, 101, 107, 121, 124, 142, 145, 146, 152, 154, 155, 158, 159, 160, 164, 182, 183, 184, 190]
def read_png_flow(flow_file):
flow_object = png.Reader(filename=flow_file)
flow_direct = flow_object.asDirect()
flow_data = list(flow_direct[2])
(w, h) = flow_direct[3]['size']
flow = np.zeros((h, w, 3), dtype=np.float64)
for i in range(len(flow_data)):
flow[i, :, 0] = flow_data[i][0::3]
flow[i, :, 1] = flow_data[i][1::3]
flow[i, :, 2] = flow_data[i][2::3]
invalid_idx = (flow[:, :, 2] == 0)
flow[:, :, 0:2] = (flow[:, :, 0:2] - 2 ** 15) / 64.0
flow[invalid_idx, 0] = 0
flow[invalid_idx, 1] = 0
return flow[:, :, 0:2], (1 - invalid_idx * 1)[:, :, None]
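The arithmetic that `read_png_flow` inverts can be checked with a small round-trip: KITTI stores each flow component in a 16-bit PNG channel as `stored = true * 64 + 2**15`, with a third channel flagging validity. The encoder below is a sketch for illustration (the real files come from the KITTI devkit), but the decode matches the expression above.

```python
import numpy as np

# Sketch of the 16-bit KITTI flow encoding that read_png_flow decodes:
# stored value = true_flow * 64 + 2**15 (the encoder here is illustrative).
def encode_kitti(u):
    return (np.asarray(u, dtype=np.float64) * 64.0 + 2 ** 15).astype(np.uint16)

def decode_kitti(raw):
    # identical to the decode in read_png_flow above
    return (raw.astype(np.float64) - 2 ** 15) / 64.0
```

Any flow value that is a multiple of 1/64 px survives the round trip exactly, which is the quantization step of the format.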
def kitti_random_crop(im1, im2, flo_f, valid_mask, crop_height=370, crop_width=1224):
height, width, _ = im1.shape
# get starting positions
x = np.random.uniform(0, width - crop_width + 1)
y = np.random.uniform(0, height - crop_height + 1)
str_x = int(x)
str_y = int(y)
end_x = int(x + crop_width)
end_y = int(y + crop_height)
im1 = im1[str_y:end_y, str_x:end_x, :]
im2 = im2[str_y:end_y, str_x:end_x, :]
flo_f = flo_f[str_y:end_y, str_x:end_x, :]
valid_mask = valid_mask[str_y:end_y, str_x:end_x, :]
return im1, im2, flo_f, valid_mask
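The crop arithmetic in `kitti_random_crop` relies on the start position being drawn from `[0, dim - crop + 1)`, so the window always fits inside the image. A minimal check of just the coordinate logic (a sketch, factored out for testing):

```python
import numpy as np

# Coordinate logic of kitti_random_crop in isolation: the start is drawn
# from [0, dim - crop + 1), so y1 <= height and x1 <= width always hold.
def random_crop_coords(height, width, crop_height, crop_width):
    x = int(np.random.uniform(0, width - crop_width + 1))
    y = int(np.random.uniform(0, height - crop_height + 1))
    return y, y + crop_height, x, x + crop_width
```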
class Kitti_comb_test(data.Dataset):
def __init__(self,
args,
images_root_2015=None,
images_root_2012=None,
photometric_augmentations=False,
preprocessing_crop=True):
self._args = args
self.preprocessing_crop = preprocessing_crop
list_of_indices_2012 = []
list_of_indices_2015 = []
# ----------------------------------------------------------
# KITTI 2015
# ----------------------------------------------------------
if images_root_2015 is not None:
if not os.path.isdir(images_root_2015):
raise ValueError("Image directory not found! {}".format(images_root_2015))
all_img1_2015_filenames = sorted(glob(os.path.join(images_root_2015, "*_10.png")))
all_img2_2015_filenames = sorted(glob(os.path.join(images_root_2015, "*_11.png")))
assert len(all_img1_2015_filenames) != 0
assert len(all_img2_2015_filenames) == len(all_img1_2015_filenames)
list_of_indices_2015 = range(len(all_img1_2015_filenames))
# ----------------------------------------------------------
# KITTI 2012
# ----------------------------------------------------------
if images_root_2012 is not None:
if not os.path.isdir(images_root_2012):
raise ValueError("Image directory not found! {}".format(images_root_2012))
all_img1_2012_filenames = sorted(glob(os.path.join(images_root_2012, "*_10.png")))
all_img2_2012_filenames = sorted(glob(os.path.join(images_root_2012, "*_11.png")))
assert len(all_img1_2012_filenames) != 0
assert len(all_img2_2012_filenames) == len(all_img1_2012_filenames)
list_of_indices_2012 = range(len(all_img1_2012_filenames))
# ----------------------------------------------------------
# Save list of actual filenames for inputs and flows
# ----------------------------------------------------------
self._image_list = []
self._flow_list = []
for ii in list_of_indices_2015:
im1 = all_img1_2015_filenames[ii]
im2 = all_img2_2015_filenames[ii]
idx1 = os.path.splitext(os.path.basename(im1))[0][:-3]
idx2 = os.path.splitext(os.path.basename(im2))[0][:-3]
assert idx1 == idx2
if not os.path.isfile(im1) or not os.path.isfile(im2):
continue
self._image_list += [[im1, im2]]
for ii in list_of_indices_2012:
im1 = all_img1_2012_filenames[ii]
im2 = all_img2_2012_filenames[ii]
idx1 = os.path.splitext(os.path.basename(im1))[0][:-3]
idx2 = os.path.splitext(os.path.basename(im2))[0][:-3]
assert idx1 == idx2
if not os.path.isfile(im1) or not os.path.isfile(im2):
continue
self._image_list += [[im1, im2]]
self._size = len(self._image_list)
assert len(self._image_list) != 0
# ----------------------------------------------------------
# photometric_augmentations
# ----------------------------------------------------------
if photometric_augmentations:
self._photometric_transform = transforms.ConcatTransformSplitChainer([
# uint8 -> PIL
vision_transforms.ToPILImage(),
# PIL -> PIL : random hsv and contrast
vision_transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
# PIL -> FloatTensor
vision_transforms.transforms.ToTensor(),
transforms.RandomGamma(min_gamma=0.7, max_gamma=1.5, clip_image=True),
], from_numpy=True, to_numpy=False)
else:
self._photometric_transform = transforms.ConcatTransformSplitChainer([
# uint8 -> FloatTensor
vision_transforms.transforms.ToTensor(),
], from_numpy=True, to_numpy=False)
def __getitem__(self, index):
index = index % self._size
im1_filename = self._image_list[index][0]
im2_filename = self._image_list[index][1]
        # read uint8 images (no ground-truth flow in the test split)
im1_np0 = common.read_image_as_byte(im1_filename)
im2_np0 = common.read_image_as_byte(im2_filename)
# possibly apply photometric transformations
im1, im2 = self._photometric_transform(im1_np0, im2_np0)
# example filename
basename = os.path.basename(im1_filename)[:6]
example_dict = {
"input1": im1,
"input2": im2,
"index": index,
"basename": basename
}
return example_dict
def __len__(self):
return self._size
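The KITTI loaders above pair `<id>_10.png` / `<id>_11.png` frames by stripping the trailing `_10` / `_11` suffix from the basename and asserting the ids match. The slice logic in isolation:

```python
import os

# Frame-id extraction used to pair KITTI "<id>_10.png" / "<id>_11.png"
# images: drop the extension, then the 3-character "_10"/"_11" suffix.
def frame_id(path):
    return os.path.splitext(os.path.basename(path))[0][:-3]
```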
class Kitti_comb(data.Dataset):
def __init__(self,
args,
images_root_2015=None,
flow_root_2015=None,
images_root_2012=None,
flow_root_2012=None,
photometric_augmentations=False,
preprocessing_crop=True,
dstype="full"):
self._args = args
self.preprocessing_crop = preprocessing_crop
list_of_indices_2012 = []
list_of_indices_2015 = []
# ----------------------------------------------------------
# KITTI 2015
# ----------------------------------------------------------
if images_root_2015 is not None and flow_root_2015 is not None:
if not os.path.isdir(images_root_2015):
raise ValueError("Image directory not found! {}".format(images_root_2015))
if not os.path.isdir(flow_root_2015):
raise ValueError("Flow directory not found! {}".format(flow_root_2015))
all_img1_2015_filenames = sorted(glob(os.path.join(images_root_2015, "*_10.png")))
all_img2_2015_filenames = sorted(glob(os.path.join(images_root_2015, "*_11.png")))
flow_f_2015_filenames = sorted(glob(os.path.join(flow_root_2015, "*_10.png")))
assert len(all_img1_2015_filenames) != 0
assert len(all_img2_2015_filenames) == len(all_img1_2015_filenames)
assert len(flow_f_2015_filenames) == len(all_img1_2015_filenames)
num_flows_2015 = len(flow_f_2015_filenames)
validate_indices_2015 = [x for x in VALIDATE_INDICES_2015 if x in range(num_flows_2015)]
if dstype == "train":
list_of_indices_2015 = [x for x in range(num_flows_2015) if x not in validate_indices_2015]
elif dstype == "valid":
list_of_indices_2015 = validate_indices_2015
elif dstype == "full":
list_of_indices_2015 = range(len(all_img1_2015_filenames))
else:
                raise ValueError("KITTI 2015: dstype '%s' unknown!" % dstype)
# ----------------------------------------------------------
# KITTI 2012
# ----------------------------------------------------------
        if images_root_2012 is not None and flow_root_2012 is not None:
            if not os.path.isdir(images_root_2012):
                raise ValueError("Image directory not found! {}".format(images_root_2012))
            if not os.path.isdir(flow_root_2012):
                raise ValueError("Flow directory not found! {}".format(flow_root_2012))
all_img1_2012_filenames = sorted(glob(os.path.join(images_root_2012, "*_10.png")))
all_img2_2012_filenames = sorted(glob(os.path.join(images_root_2012, "*_11.png")))
flow_f_2012_filenames = sorted(glob(os.path.join(flow_root_2012, "*_10.png")))
assert len(all_img1_2012_filenames) != 0
assert len(all_img2_2012_filenames) == len(all_img1_2012_filenames)
assert len(flow_f_2012_filenames) == len(all_img1_2012_filenames)
num_flows_2012 = len(flow_f_2012_filenames)
validate_indices_2012 = [x for x in VALIDATE_INDICES_2012 if x in range(num_flows_2012)]
if dstype == "train":
list_of_indices_2012 = [x for x in range(num_flows_2012) if x not in validate_indices_2012]
elif dstype == "valid":
list_of_indices_2012 = validate_indices_2012
elif dstype == "full":
list_of_indices_2012 = range(len(all_img1_2012_filenames))
else:
                raise ValueError("KITTI 2012: dstype '%s' unknown!" % dstype)
# ----------------------------------------------------------
# Save list of actual filenames for inputs and flows
# ----------------------------------------------------------
self._image_list = []
self._flow_list = []
for ii in list_of_indices_2015:
im1 = all_img1_2015_filenames[ii]
im2 = all_img2_2015_filenames[ii]
idx1 = os.path.splitext(os.path.basename(im1))[0][:-3]
idx2 = os.path.splitext(os.path.basename(im2))[0][:-3]
assert idx1 == idx2
if not os.path.isfile(im1) or not os.path.isfile(im2):
continue
self._image_list += [[im1, im2]]
            if dstype != "test":
flo_f = flow_f_2015_filenames[ii]
idx_f = os.path.splitext(os.path.basename(flo_f))[0][:-3]
assert idx1 == idx_f
if not os.path.isfile(flo_f):
continue
self._flow_list += [[flo_f]]
for ii in list_of_indices_2012:
im1 = all_img1_2012_filenames[ii]
im2 = all_img2_2012_filenames[ii]
idx1 = os.path.splitext(os.path.basename(im1))[0][:-3]
idx2 = os.path.splitext(os.path.basename(im2))[0][:-3]
assert idx1 == idx2
if not os.path.isfile(im1) or not os.path.isfile(im2):
continue
self._image_list += [[im1, im2]]
            if dstype != "test":
flo_f = flow_f_2012_filenames[ii]
idx_f = os.path.splitext(os.path.basename(flo_f))[0][:-3]
assert idx1 == idx_f
if not os.path.isfile(flo_f):
continue
self._flow_list += [[flo_f]]
self._size = len(self._image_list)
assert len(self._image_list) != 0
        if dstype != "test":
assert len(self._image_list) == len(self._flow_list)
# ----------------------------------------------------------
# photometric_augmentations
# ----------------------------------------------------------
if photometric_augmentations:
self._photometric_transform = transforms.ConcatTransformSplitChainer([
# uint8 -> PIL
vision_transforms.ToPILImage(),
# PIL -> PIL : random hsv and contrast
vision_transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
# PIL -> FloatTensor
vision_transforms.transforms.ToTensor(),
transforms.RandomGamma(min_gamma=0.7, max_gamma=1.5, clip_image=True),
], from_numpy=True, to_numpy=False)
else:
self._photometric_transform = transforms.ConcatTransformSplitChainer([
# uint8 -> FloatTensor
vision_transforms.transforms.ToTensor(),
], from_numpy=True, to_numpy=False)
def __getitem__(self, index):
index = index % self._size
im1_filename = self._image_list[index][0]
im2_filename = self._image_list[index][1]
flo_f_filename = self._flow_list[index][0]
        # read uint8 images and decode the 16-bit PNG flow
im1_np0 = common.read_image_as_byte(im1_filename)
im2_np0 = common.read_image_as_byte(im2_filename)
flo_f_np0, valid_mask = read_png_flow(flo_f_filename)
if self.preprocessing_crop:
im1_np0, im2_np0, flo_f_np0, valid_mask = kitti_random_crop(im1_np0, im2_np0, flo_f_np0, valid_mask)
# possibly apply photometric transformations
im1, im2 = self._photometric_transform(im1_np0, im2_np0)
# convert flow to FloatTensor
flo_f = common.numpy2torch(flo_f_np0)
valid_mask_f = common.numpy2torch(valid_mask)
# example filename
basename = os.path.basename(im1_filename)[:6]
example_dict = {
"input1": im1,
"input2": im2,
            "target1": flo_f,
            "target2": flo_f,  # KITTI has no backward ground truth; forward flow is reused as a placeholder
"index": index,
"basename": basename,
"input_valid": valid_mask_f
}
return example_dict
def __len__(self):
return self._size
class KittiCombTrain(Kitti_comb):
def __init__(self,
args,
root,
photometric_augmentations=True,
preprocessing_crop=True):
images_root_2015 = os.path.join(root, "data_scene_flow", "training", "image_2")
flow_root_2015 = os.path.join(root, "data_scene_flow", "training", "flow_occ")
images_root_2012 = os.path.join(root, "data_stereo_flow", "training", "colored_0")
flow_root_2012 = os.path.join(root, "data_stereo_flow", "training", "flow_occ")
super(KittiCombTrain, self).__init__(
args,
images_root_2015=images_root_2015,
flow_root_2015=flow_root_2015,
images_root_2012=images_root_2012,
flow_root_2012=flow_root_2012,
photometric_augmentations=photometric_augmentations,
preprocessing_crop=preprocessing_crop,
dstype="train")
class KittiCombVal(Kitti_comb):
def __init__(self,
args,
root,
photometric_augmentations=False,
preprocessing_crop=False):
images_root_2015 = os.path.join(root, "data_scene_flow", "training", "image_2")
flow_root_2015 = os.path.join(root, "data_scene_flow", "training", "flow_occ")
images_root_2012 = os.path.join(root, "data_stereo_flow", "training", "colored_0")
flow_root_2012 = os.path.join(root, "data_stereo_flow", "training", "flow_occ")
super(KittiCombVal, self).__init__(
args,
images_root_2015=images_root_2015,
flow_root_2015=flow_root_2015,
images_root_2012=images_root_2012,
flow_root_2012=flow_root_2012,
photometric_augmentations=photometric_augmentations,
preprocessing_crop=preprocessing_crop,
dstype="valid")
class KittiCombFull(Kitti_comb):
def __init__(self,
args,
root,
photometric_augmentations=True,
preprocessing_crop=True):
images_root_2015 = os.path.join(root, "data_scene_flow", "training", "image_2")
flow_root_2015 = os.path.join(root, "data_scene_flow", "training", "flow_occ")
images_root_2012 = os.path.join(root, "data_stereo_flow", "training", "colored_0")
flow_root_2012 = os.path.join(root, "data_stereo_flow", "training", "flow_occ")
super(KittiCombFull, self).__init__(
args,
images_root_2015=images_root_2015,
flow_root_2015=flow_root_2015,
images_root_2012=images_root_2012,
flow_root_2012=flow_root_2012,
photometric_augmentations=photometric_augmentations,
preprocessing_crop=preprocessing_crop,
dstype="full")
class KittiComb2015Train(Kitti_comb):
def __init__(self,
args,
root,
photometric_augmentations=True,
preprocessing_crop=True):
images_root_2015 = os.path.join(root, "data_scene_flow", "training", "image_2")
flow_root_2015 = os.path.join(root, "data_scene_flow", "training", "flow_occ")
super(KittiComb2015Train, self).__init__(
args,
images_root_2015=images_root_2015,
flow_root_2015=flow_root_2015,
photometric_augmentations=photometric_augmentations,
preprocessing_crop=preprocessing_crop,
dstype="train")
class KittiComb2015Val(Kitti_comb):
def __init__(self,
args,
root,
photometric_augmentations=False,
preprocessing_crop=False):
images_root_2015 = os.path.join(root, "data_scene_flow", "training", "image_2")
flow_root_2015 = os.path.join(root, "data_scene_flow", "training", "flow_occ")
super(KittiComb2015Val, self).__init__(
args,
images_root_2015=images_root_2015,
flow_root_2015=flow_root_2015,
photometric_augmentations=photometric_augmentations,
preprocessing_crop=preprocessing_crop,
dstype="valid")
class KittiComb2015Full(Kitti_comb):
def __init__(self,
args,
root,
photometric_augmentations=True,
preprocessing_crop=True):
images_root_2015 = os.path.join(root, "data_scene_flow", "training", "image_2")
flow_root_2015 = os.path.join(root, "data_scene_flow", "training", "flow_occ")
super(KittiComb2015Full, self).__init__(
args,
images_root_2015=images_root_2015,
flow_root_2015=flow_root_2015,
photometric_augmentations=photometric_augmentations,
preprocessing_crop=preprocessing_crop,
dstype="full")
class KittiComb2015Test(Kitti_comb_test):
def __init__(self,
args,
root,
photometric_augmentations=False,
preprocessing_crop=False):
images_root_2015 = os.path.join(root, "data_scene_flow", "testing", "image_2")
super(KittiComb2015Test, self).__init__(
args,
images_root_2015=images_root_2015,
photometric_augmentations=photometric_augmentations,
preprocessing_crop=preprocessing_crop)
class KittiComb2012Train(Kitti_comb):
def __init__(self,
args,
root,
photometric_augmentations=True,
preprocessing_crop=True):
images_root_2012 = os.path.join(root, "data_stereo_flow", "training", "colored_0")
flow_root_2012 = os.path.join(root, "data_stereo_flow", "training", "flow_occ")
super(KittiComb2012Train, self).__init__(
args,
images_root_2012=images_root_2012,
flow_root_2012=flow_root_2012,
photometric_augmentations=photometric_augmentations,
preprocessing_crop=preprocessing_crop,
dstype="train")
class KittiComb2012Val(Kitti_comb):
def __init__(self,
args,
root,
photometric_augmentations=False,
preprocessing_crop=False):
images_root_2012 = os.path.join(root, "data_stereo_flow", "training", "colored_0")
flow_root_2012 = os.path.join(root, "data_stereo_flow", "training", "flow_occ")
super(KittiComb2012Val, self).__init__(
args,
images_root_2012=images_root_2012,
flow_root_2012=flow_root_2012,
photometric_augmentations=photometric_augmentations,
preprocessing_crop=preprocessing_crop,
dstype="valid")
class KittiComb2012Full(Kitti_comb):
def __init__(self,
args,
root,
photometric_augmentations=True,
preprocessing_crop=True):
images_root_2012 = os.path.join(root, "data_stereo_flow", "training", "colored_0")
flow_root_2012 = os.path.join(root, "data_stereo_flow", "training", "flow_occ")
super(KittiComb2012Full, self).__init__(
args,
images_root_2012=images_root_2012,
flow_root_2012=flow_root_2012,
photometric_augmentations=photometric_augmentations,
preprocessing_crop=preprocessing_crop,
dstype="full")
class KittiComb2012Test(Kitti_comb_test):
def __init__(self,
args,
root,
photometric_augmentations=False,
preprocessing_crop=False):
images_root_2012 = os.path.join(root, "data_stereo_flow", "testing", "colored_0")
super(KittiComb2012Test, self).__init__(
args,
images_root_2012=images_root_2012,
photometric_augmentations=photometric_augmentations,
preprocessing_crop=preprocessing_crop)
================================================
FILE: datasets/sintel.py
================================================
from __future__ import absolute_import, division, print_function
import os
import torch.utils.data as data
from glob import glob
from torchvision import transforms as vision_transforms
from . import transforms
from . import common
import tools
VALIDATE_INDICES = [
199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210,
211, 212, 213, 214, 215, 216, 217, 340, 341, 342, 343, 344,
345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356,
357, 358, 359, 360, 361, 362, 363, 364, 536, 537, 538, 539,
540, 541, 542, 543, 544, 545, 546, 547, 548, 549, 550, 551,
552, 553, 554, 555, 556, 557, 558, 559, 560, 659, 660, 661,
662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673,
674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685,
686, 687, 688, 689, 690, 691, 692, 693, 694, 695, 696, 697,
967, 968, 969, 970, 971, 972, 973, 974, 975, 976, 977, 978,
979, 980, 981, 982, 983, 984, 985, 986, 987, 988, 989, 990,
991]
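The list above carves a fixed validation split out of the sequential training pairs; `_Sintel.__init__` drops out-of-range indices and sends every remaining example to the training split. A minimal sketch of that partition (hypothetical `split_indices` helper, pure Python, illustration only):

```python
def split_indices(num_examples, validate_indices):
    # Drop validation indices that fall outside the dataset
    valid = [i for i in validate_indices if i < num_examples]
    # Every remaining index goes to the training split
    train = [i for i in range(num_examples) if i not in valid]
    return train, valid

train_idx, valid_idx = split_indices(10, [2, 5, 99])  # 99 is out of range and dropped
```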
class _Sintel(data.Dataset):
def __init__(self,
args,
dir_root=None,
photometric_augmentations=False,
imgtype=None,
dstype=None):
self._args = args
images_root = os.path.join(dir_root, imgtype)
if imgtype == "comb":
images_root = os.path.join(dir_root, "clean")
flow_root = os.path.join(dir_root, "flow")
occ_root = os.path.join(dir_root, "occlusions_rev")
if not os.path.isdir(images_root):
raise ValueError("Image directory '%s' not found!" % images_root)
if flow_root is not None and not os.path.isdir(flow_root):
raise ValueError("Flow directory '%s' not found!" % flow_root)
if occ_root is not None and not os.path.isdir(occ_root):
raise ValueError("Occ directory '%s' not found!" % occ_root)
all_flo_filenames = sorted(glob(os.path.join(flow_root, "*/*.flo")))
all_occ_filenames = sorted(glob(os.path.join(occ_root, "*/*.png")))
all_img_filenames = sorted(glob(os.path.join(images_root, "*/*.png")))
# Remember base for subtraction at runtime
# e.g. subtract_base = "/home/user/.../MPI-Sintel-Complete/training/clean"
self._substract_base = tools.cd_dotdot(images_root)
# ------------------------------------------------------------------------
# Get unique basenames
# ------------------------------------------------------------------------
# e.g. base_folders = ["alley_1", "alley_2", "ambush_2", ...]
substract_full_base = tools.cd_dotdot(all_img_filenames[0])
base_folders = sorted(list(set([
os.path.dirname(fn.replace(substract_full_base, ""))[1:] for fn in all_img_filenames
])))
self._image_list = []
self._flow_list = []
self._occ_list = []
for base_folder in base_folders:
img_filenames = [x for x in all_img_filenames if base_folder in x]
flo_filenames = [x for x in all_flo_filenames if base_folder in x]
occ_filenames = [x for x in all_occ_filenames if base_folder in x]
for i in range(len(img_filenames) - 1):
im1 = img_filenames[i]
im2 = img_filenames[i + 1]
flo = flo_filenames[i]
occ = occ_filenames[i]
self._image_list += [[im1, im2]]
self._flow_list += [flo]
self._occ_list += [occ]
# Sanity check
im1_base_filename = os.path.splitext(os.path.basename(im1))[0]
im2_base_filename = os.path.splitext(os.path.basename(im2))[0]
flo_base_filename = os.path.splitext(os.path.basename(flo))[0]
occ_base_filename = os.path.splitext(os.path.basename(occ))[0]
im1_frame, im1_no = im1_base_filename.split("_")
im2_frame, im2_no = im2_base_filename.split("_")
assert(im1_frame == im2_frame)
assert(int(im1_no) == int(im2_no) - 1)
flo_frame, flo_no = flo_base_filename.split("_")
assert(im1_frame == flo_frame)
assert(int(im1_no) == int(flo_no))
occ_frame, occ_no = occ_base_filename.split("_")
assert(im1_frame == occ_frame)
assert(int(im1_no) == int(occ_no))
assert len(self._image_list) == len(self._flow_list)
assert len(self._image_list) == len(self._occ_list)
# -------------------------------------------------------------
# Remove invalid validation indices
# -------------------------------------------------------------
full_num_examples = len(self._image_list)
validate_indices = [x for x in VALIDATE_INDICES if x in range(full_num_examples)]
# ----------------------------------------------------------
# Construct list of indices for training/validation
# ----------------------------------------------------------
list_of_indices = None
if dstype == "train":
list_of_indices = [x for x in range(full_num_examples) if x not in validate_indices]
elif dstype == "valid":
list_of_indices = validate_indices
elif dstype == "full":
list_of_indices = range(full_num_examples)
else:
raise ValueError("dstype '%s' unknown!" % dstype)
# ----------------------------------------------------------
# Save list of actual filenames for inputs and flows
# ----------------------------------------------------------
self._image_list = [self._image_list[i] for i in list_of_indices]
self._flow_list = [self._flow_list[i] for i in list_of_indices]
self._occ_list = [self._occ_list[i] for i in list_of_indices]
if imgtype == "comb":
image_list_final = [[im1.replace("clean", "final"), im2.replace("clean", "final")] for im1, im2 in self._image_list]
self._image_list += image_list_final
self._flow_list += self._flow_list
self._occ_list += self._occ_list
assert len(self._image_list) == len(self._flow_list)
assert len(self._image_list) == len(self._occ_list)
# ----------------------------------------------------------
# photometric_augmentations
# ----------------------------------------------------------
if photometric_augmentations:
self._photometric_transform = transforms.ConcatTransformSplitChainer([
# uint8 -> PIL
vision_transforms.ToPILImage(),
# PIL -> PIL : random hsv and contrast
vision_transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
# PIL -> FloatTensor
vision_transforms.transforms.ToTensor(),
transforms.RandomGamma(min_gamma=0.7, max_gamma=1.5, clip_image=True),
], from_numpy=True, to_numpy=False)
else:
self._photometric_transform = transforms.ConcatTransformSplitChainer([
# uint8 -> FloatTensor
vision_transforms.transforms.ToTensor(),
], from_numpy=True, to_numpy=False)
self._size = len(self._image_list)
def __getitem__(self, index):
index = index % self._size
im1_filename = self._image_list[index][0]
im2_filename = self._image_list[index][1]
flo_filename = self._flow_list[index]
occ_filename = self._occ_list[index]
# read float32 images and flow
im1_np0 = common.read_image_as_byte(im1_filename)
im2_np0 = common.read_image_as_byte(im2_filename)
flo_np0 = common.read_flo_as_float32(flo_filename)
occ_np0 = common.read_occ_image_as_float32(occ_filename)
# possibly apply photometric transformations
im1, im2 = self._photometric_transform(im1_np0, im2_np0)
flo = common.numpy2torch(flo_np0)
occ = common.numpy2torch(occ_np0)
# e.g. "clean/alley_1/"
basedir = os.path.splitext(os.path.dirname(im1_filename).replace(self._substract_base, "")[1:])[0]
# example filename
basename = os.path.splitext(os.path.basename(im1_filename))[0]
example_dict = {
"input1": im1,
"input2": im2,
"index": index,
"basedir": basedir,
"basename": basename,
"target1": flo,
"target_occ1": occ
}
return example_dict
def __len__(self):
return self._size
class _Sintel_test(data.Dataset):
def __init__(self,
args,
dir_root=None,
photometric_augmentations=False,
imgtype=None):
self._args = args
images_root = os.path.join(dir_root, imgtype)
if not os.path.isdir(images_root):
raise ValueError("Image directory '%s' not found!" % images_root)
all_img_filenames = sorted(glob(os.path.join(images_root, "*/*.png")))
# Remember base for subtraction at runtime
# e.g. subtract_base = "/home/user/.../MPI-Sintel-Complete/training/clean"
self._substract_base = tools.cd_dotdot(images_root)
# ------------------------------------------------------------------------
# Get unique basenames
# ------------------------------------------------------------------------
# e.g. base_folders = ["alley_1", "alley_2", "ambush_2", ...]
substract_full_base = tools.cd_dotdot(all_img_filenames[0])
base_folders = sorted(list(set([
os.path.dirname(fn.replace(substract_full_base, ""))[1:] for fn in all_img_filenames
])))
self._image_list = []
for base_folder in base_folders:
img_filenames = [x for x in all_img_filenames if base_folder in x]
for i in range(len(img_filenames) - 1):
im1 = img_filenames[i]
im2 = img_filenames[i + 1]
self._image_list += [[im1, im2]]
# Sanity check
im1_base_filename = os.path.splitext(os.path.basename(im1))[0]
im2_base_filename = os.path.splitext(os.path.basename(im2))[0]
im1_frame, im1_no = im1_base_filename.split("_")
im2_frame, im2_no = im2_base_filename.split("_")
assert(im1_frame == im2_frame)
assert(int(im1_no) == int(im2_no) - 1)
full_num_examples = len(self._image_list)
list_of_indices = range(full_num_examples)
# ----------------------------------------------------------
# Save list of actual filenames for inputs and flows
# ----------------------------------------------------------
self._image_list = [self._image_list[i] for i in list_of_indices]
# ----------------------------------------------------------
# photometric_augmentations
# ----------------------------------------------------------
if photometric_augmentations:
self._photometric_transform = transforms.ConcatTransformSplitChainer([
# uint8 -> PIL
vision_transforms.ToPILImage(),
# PIL -> PIL : random hsv and contrast
vision_transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
# PIL -> FloatTensor
vision_transforms.transforms.ToTensor(),
transforms.RandomGamma(min_gamma=0.7, max_gamma=1.5, clip_image=True),
], from_numpy=True, to_numpy=False)
else:
self._photometric_transform = transforms.ConcatTransformSplitChainer([
# uint8 -> FloatTensor
vision_transforms.transforms.ToTensor(),
], from_numpy=True, to_numpy=False)
self._size = len(self._image_list)
def __getitem__(self, index):
index = index % self._size
im1_filename = self._image_list[index][0]
im2_filename = self._image_list[index][1]
# read float32 images and flow
im1_np0 = common.read_image_as_byte(im1_filename)
im2_np0 = common.read_image_as_byte(im2_filename)
# possibly apply photometric transformations
im1, im2 = self._photometric_transform(im1_np0, im2_np0)
# e.g. "clean/alley_1/"
basedir = os.path.splitext(os.path.dirname(im1_filename).replace(self._substract_base, "")[1:])[0]
# example filename
basename = os.path.splitext(os.path.basename(im1_filename))[0]
example_dict = {
"input1": im1,
"input2": im2,
"index": index,
"basedir": basedir,
"basename": basename
}
return example_dict
def __len__(self):
return self._size
class SintelTrainingCleanTrain(_Sintel):
def __init__(self, args, root, photometric_augmentations=True):
dir_root = os.path.join(root, "training")
super(SintelTrainingCleanTrain, self).__init__(
args,
dir_root=dir_root,
photometric_augmentations=photometric_augmentations,
imgtype="clean",
dstype="train")
class SintelTrainingCleanValid(_Sintel):
def __init__(self, args, root, photometric_augmentations=False):
dir_root = os.path.join(root, "training")
super(SintelTrainingCleanValid, self).__init__(
args,
dir_root=dir_root,
photometric_augmentations=photometric_augmentations,
imgtype="clean",
dstype="valid")
class SintelTrainingCleanFull(_Sintel):
def __init__(self, args, root, photometric_augmentations=True):
dir_root = os.path.join(root, "training")
super(SintelTrainingCleanFull, self).__init__(
args,
dir_root=dir_root,
photometric_augmentations=photometric_augmentations,
imgtype="clean",
dstype="full")
class SintelTrainingFinalTrain(_Sintel):
def __init__(self, args, root, photometric_augmentations=True):
dir_root = os.path.join(root, "training")
super(SintelTrainingFinalTrain, self).__init__(
args,
dir_root=dir_root,
photometric_augmentations=photometric_augmentations,
imgtype="final",
dstype="train")
class SintelTrainingFinalValid(_Sintel):
def __init__(self, args, root, photometric_augmentations=False):
dir_root = os.path.join(root, "training")
super(SintelTrainingFinalValid, self).__init__(
args,
dir_root=dir_root,
photometric_augmentations=photometric_augmentations,
imgtype="final",
dstype="valid")
class SintelTrainingFinalFull(_Sintel):
def __init__(self, args, root, photometric_augmentations=True):
dir_root = os.path.join(root, "training")
super(SintelTrainingFinalFull, self).__init__(
args,
dir_root=dir_root,
photometric_augmentations=photometric_augmentations,
imgtype="final",
dstype="full")
class SintelTrainingCombTrain(_Sintel):
def __init__(self, args, root, photometric_augmentations=True):
dir_root = os.path.join(root, "training")
super(SintelTrainingCombTrain, self).__init__(
args,
dir_root=dir_root,
photometric_augmentations=photometric_augmentations,
imgtype="comb",
dstype="train")
class SintelTrainingCombValid(_Sintel):
def __init__(self, args, root, photometric_augmentations=False):
dir_root = os.path.join(root, "training")
super(SintelTrainingCombValid, self).__init__(
args,
dir_root=dir_root,
photometric_augmentations=photometric_augmentations,
imgtype="comb",
dstype="valid")
class SintelTrainingCombFull(_Sintel):
def __init__(self, args, root, photometric_augmentations=True):
dir_root = os.path.join(root, "training")
super(SintelTrainingCombFull, self).__init__(
args,
dir_root=dir_root,
photometric_augmentations=photometric_augmentations,
imgtype="comb",
dstype="full")
class SintelTestClean(_Sintel_test):
def __init__(self, args, root, photometric_augmentations=False):
dir_root = os.path.join(root, "test")
super(SintelTestClean, self).__init__(
args,
dir_root=dir_root,
photometric_augmentations=photometric_augmentations,
imgtype="clean")
class SintelTestFinal(_Sintel_test):
def __init__(self, args, root, photometric_augmentations=False):
dir_root = os.path.join(root, "test")
super(SintelTestFinal, self).__init__(
args,
dir_root=dir_root,
photometric_augmentations=photometric_augmentations,
imgtype="final")
================================================
FILE: datasets/transforms.py
================================================
## Portions of code copyright 2018 Jochen Gast
from __future__ import absolute_import, division, print_function
import numpy as np
import torch
def image_random_gamma(image, min_gamma=0.7, max_gamma=1.5, clip_image=False):
gamma = np.random.uniform(min_gamma, max_gamma)
adjusted = torch.pow(image, gamma)
if clip_image:
adjusted.clamp_(0.0, 1.0)
return adjusted
class RandomGamma:
def __init__(self, min_gamma=0.7, max_gamma=1.5, clip_image=False):
self._min_gamma = min_gamma
self._max_gamma = max_gamma
self._clip_image = clip_image
def __call__(self, image):
return image_random_gamma(
image,
min_gamma=self._min_gamma,
max_gamma=self._max_gamma,
clip_image=self._clip_image)
# ------------------------------------------------------------------
# Allow transformation chains of the type:
# im1, im2, .... = transform(im1, im2, ...)
# ------------------------------------------------------------------
class TransformChainer:
def __init__(self, list_of_transforms):
self._list_of_transforms = list_of_transforms
def __call__(self, *args):
list_of_args = list(args)
for transform in self._list_of_transforms:
list_of_args = [transform(arg) for arg in list_of_args]
if len(args) == 1:
return list_of_args[0]
else:
return list_of_args
# ------------------------------------------------------------------
# Allow transformation chains of the type:
# im1, im2, .... = split( transform( concatenate(im1, im2, ...) ))
# ------------------------------------------------------------------
class ConcatTransformSplitChainer:
def __init__(self, list_of_transforms, from_numpy=True, to_numpy=False):
self._chainer = TransformChainer(list_of_transforms)
self._from_numpy = from_numpy
self._to_numpy = to_numpy
def __call__(self, *args):
num_splits = len(args)
if self._from_numpy:
concatenated = np.concatenate(args, axis=0)
else:
concatenated = torch.cat(args, dim=1)
transformed = self._chainer(concatenated)
if self._to_numpy:
split = np.split(transformed, indices_or_sections=num_splits, axis=0)
else:
split = torch.chunk(transformed, num_splits, dim=1)
return split
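As a standalone illustration of the concat-transform-split pattern (NumPy only; `concat_transform_split` is a hypothetical stand-in for the chainer above): concatenating first guarantees that a random transform draws its parameters once and applies them identically to every input, instead of drawing fresh parameters per image:

```python
import numpy as np

def concat_transform_split(transform, *images):
    # Concatenate along the row axis, transform once, split back apart
    stacked = np.concatenate(images, axis=0)
    transformed = transform(stacked)
    return np.split(transformed, len(images), axis=0)

# A random gain applied through the chainer hits im1 and im2 identically;
# transforming each image separately would draw two different gains.
rng = np.random.RandomState(0)
gain = rng.uniform(0.5, 1.5)
im1 = np.ones((4, 4), dtype=np.float32)
im2 = np.full((4, 4), 2.0, dtype=np.float32)
out1, out2 = concat_transform_split(lambda x: x * gain, im1, im2)
```

This is exactly why the photometric augmentations in the dataset classes run through `ConcatTransformSplitChainer` rather than per-image.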
================================================
FILE: flyingchairsocc/README.md
================================================
# FlyingChairsOcc dataset
<img src=demo_img.png>
The FlyingChairsOcc dataset is an extended version of the <a href="https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html" target="_blank">Flying Chairs Dataset</a>, including bi-directional optical flow ground truth and two occlusion maps for each image.
You may also find the concurrent <a href="https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html" target="_blank">Flying Chairs 2 Dataset</a> useful.
## License agreement
This dataset is made freely available to academic and non-academic entities for non-commercial purposes such as academic research, teaching, scientific publications, or personal experimentation. Permission is granted to use the data given that you agree:
1. That the dataset comes “AS IS”, without express or implied warranty. Although every effort has been made to ensure accuracy, we (TU Darmstadt) do not accept any responsibility for errors or omissions.
2. That you include a reference to the FlyingChairsOcc Dataset in any work that makes use of the dataset. For research papers, cite our publication: J. Hur and S. Roth, “Iterative residual refinement for joint optical flow and occlusion estimation,” in Proc. of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), Long Beach, California, June 2019
3. That you do not distribute this dataset or modified versions. It is permissible to distribute derivative works in as far as they are abstract representations of this dataset (such as models trained on it or additional annotations that do not directly include any of our data) and do not allow to recover the dataset or something similar in character.
4. That you may not use the dataset or any derivative work for commercial purposes as, for example, licensing or selling the data, or using the data with a purpose to procure a commercial gain.
5. That all rights not expressly granted to you are reserved by us (TU Darmstadt).
## Download link
<a href="https://download.visinf.tu-darmstadt.de/data/flyingchairs_occ/FlyingChairsOcc.tar.gz" target="_blank"><b>Download</b></a>
(82GB)
## Reference
Please cite the paper below if you find the dataset and source code useful.
    @inproceedings{Hur:2019:IRR,
      Author = {Junhwa Hur and Stefan Roth},
      Booktitle = {CVPR},
      Title = {Iterative Residual Refinement for Joint Optical Flow and Occlusion Estimation},
      Year = {2019}
    }
Contact: junhwa.hur[at]visinf.tu-darmstadt.de
================================================
FILE: install.sh
================================================
#!/bin/bash
cd ./models/correlation_package
python setup.py install
cd ..
================================================
FILE: logger.py
================================================
## Portions of code copyright 2018 Jochen Gast
from __future__ import absolute_import, division, print_function
import colorama
import logging
import os
import re
import tools
import sys
def get_default_logging_format(colorize=False, brackets=False):
style = colorama.Style.DIM if colorize else ''
# color = colorama.Fore.CYAN if colorize else ''
color = colorama.Fore.WHITE if colorize else ''
reset = colorama.Style.RESET_ALL if colorize else ''
if brackets:
result = "{}{}[%(asctime)s]{} %(message)s".format(style, color, reset)
else:
result = "{}{}%(asctime)s{} %(message)s".format(style, color, reset)
return result
def get_default_logging_datefmt():
return "%Y-%m-%d %H:%M:%S"
def log_module_info(module):
lines = module.__str__().split("\n")
for line in lines:
logging.info(line)
class LogbookFormatter(logging.Formatter):
def __init__(self, fmt=None, datefmt=None):
super(LogbookFormatter, self).__init__(fmt=fmt, datefmt=datefmt)
self._re = re.compile(r"\033\[[0-9]+m")
def remove_colors_from_msg(self, msg):
msg = re.sub(self._re, "", msg)
return msg
def format(self, record=None):
record.msg = self.remove_colors_from_msg(record.msg)
return super(LogbookFormatter, self).format(record)
class ConsoleFormatter(logging.Formatter):
def __init__(self, fmt=None, datefmt=None):
super(ConsoleFormatter, self).__init__(fmt=fmt, datefmt=datefmt)
def format(self, record=None):
indent = sys.modules[__name__].global_indent
record.msg = " " * indent + record.msg
return super(ConsoleFormatter, self).format(record)
class SkipLogbookFilter(logging.Filter):
def filter(self, record):
return record.levelno != logging.LOGBOOK
def configure_logging(filename=None):
# set global indent level
sys.modules[__name__].global_indent = 0
# add custom LOGBOOK logging level
tools.addLoggingLevel("LOGBOOK", 1000)
# create logger
root_logger = logging.getLogger("")
root_logger.setLevel(logging.INFO)
# create console handler and set level to debug
console = logging.StreamHandler()
console.setLevel(logging.INFO)
fmt = get_default_logging_format(colorize=True, brackets=False)
datefmt = get_default_logging_datefmt()
formatter = ConsoleFormatter(fmt=fmt, datefmt=datefmt)
console.setFormatter(formatter)
# Skip LOGBOOK-level records for console output
skip_logbook_filter = SkipLogbookFilter()
console.addFilter(skip_logbook_filter)
# add console to root_logger
root_logger.addHandler(console)
# add logbook
if filename is not None:
# ensure dir
d = os.path.dirname(filename)
if not os.path.exists(d):
os.makedirs(d)
# --------------------------------------------------------------------------------------
# Configure handler that removes color codes from logbook
# --------------------------------------------------------------------------------------
logbook = logging.FileHandler(filename=filename, mode="a", encoding="utf-8")
logbook.setLevel(logging.INFO)
fmt = get_default_logging_format(colorize=False, brackets=True)
logbook_formatter = LogbookFormatter(fmt=fmt, datefmt=datefmt)
logbook.setFormatter(logbook_formatter)
root_logger.addHandler(logbook)
class LoggingBlock:
def __init__(self, title, emph=False):
self._emph = emph
bright = colorama.Style.BRIGHT
cyan = colorama.Fore.CYAN
reset = colorama.Style.RESET_ALL
if emph:
logging.info("%s==>%s %s%s%s" % (cyan, reset, bright, title, reset))
else:
logging.info(title)
def __enter__(self):
sys.modules[__name__].global_indent += 2
return self
def __exit__(self, exc_type, exc_value, traceback):
sys.modules[__name__].global_indent -= 2
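The indentation mechanism above can be sketched with the standard library alone (hypothetical names, no colorama): a module-level counter that the context manager bumps on entry and a formatter prepends to every message:

```python
import io
import logging

_indent = {"level": 0}

class IndentFormatter(logging.Formatter):
    # Prepend the current global indent to every record
    def format(self, record):
        record.msg = " " * _indent["level"] + str(record.msg)
        return super(IndentFormatter, self).format(record)

class Block:
    # Log a title unindented, then indent everything inside the `with` body
    def __init__(self, title):
        logging.info(title)
    def __enter__(self):
        _indent["level"] += 2
        return self
    def __exit__(self, exc_type, exc_value, traceback):
        _indent["level"] -= 2

stream = io.StringIO()
handler = logging.StreamHandler(stream)
handler.setFormatter(IndentFormatter("%(message)s"))
root = logging.getLogger()
root.setLevel(logging.INFO)
root.addHandler(handler)

logging.info("top")
with Block("section"):
    logging.info("inside")
logging.info("after")

lines = stream.getvalue().splitlines()
```

Nested `with` blocks compose naturally, which is how the training loop gets its hierarchical console output.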
================================================
FILE: losses.py
================================================
from __future__ import absolute_import, division, print_function
import torch
import torch.nn as nn
import torch.nn.functional as tf
def _elementwise_epe(input_flow, target_flow):
residual = target_flow - input_flow
return torch.norm(residual, p=2, dim=1, keepdim=True)
def _elementwise_robust_epe_char(input_flow, target_flow):
residual = target_flow - input_flow
return torch.pow(torch.norm(residual, p=2, dim=1, keepdim=True) + 0.01, 0.4)
def _downsample2d_as(inputs, target_as):
_, _, h, w = target_as.size()
return tf.adaptive_avg_pool2d(inputs, [h, w])
def _upsample2d_as(inputs, target_as, mode="bilinear"):
_, _, h, w = target_as.size()
return tf.interpolate(inputs, [h, w], mode=mode, align_corners=True)
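The per-pixel end-point error computed by `_elementwise_epe` is the L2 norm of the flow residual over the channel dimension. A NumPy equivalent for intuition (illustration only; the module itself operates on 4D torch tensors with a batch axis):

```python
import numpy as np

def elementwise_epe(pred, target):
    # pred, target: (2, H, W) flow fields; the EPE map is the L2 norm of
    # the residual over the two flow channels, kept as shape (1, H, W)
    return np.linalg.norm(target - pred, axis=0, keepdims=True)

pred = np.zeros((2, 3, 3), dtype=np.float32)
target = np.full((2, 3, 3), 3.0, dtype=np.float32)
epe = elementwise_epe(pred, target)  # each pixel errs by sqrt(3**2 + 3**2)
```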
def f1_score(y_true, y_pred):
return fbeta_score(y_true, y_pred, 1)
def fbeta_score(y_true, y_pred, beta, eps=1e-8):
beta2 = beta ** 2
y_pred = y_pred.float()
y_true = y_true.float()
true_positive = (y_pred * y_true).sum(dim=2).sum(dim=2)
precision = true_positive / (y_pred.sum(dim=2).sum(dim=2) + eps)
recall = true_positive / (y_true.sum(dim=2).sum(dim=2) + eps)
return torch.mean(precision * recall / (precision * beta2 + recall + eps) * (1 + beta2))
def f1_score_bal_loss(y_pred, y_true):
eps = 1e-8
tp = -(y_true * torch.log(y_pred + eps)).sum(dim=2).sum(dim=2).sum(dim=1)
fn = -((1 - y_true) * torch.log((1 - y_pred) + eps)).sum(dim=2).sum(dim=2).sum(dim=1)
denom_tp = y_true.sum(dim=2).sum(dim=2).sum(dim=1) + y_pred.sum(dim=2).sum(dim=2).sum(dim=1) + eps
denom_fn = (1 - y_true).sum(dim=2).sum(dim=2).sum(dim=1) + (1 - y_pred).sum(dim=2).sum(dim=2).sum(dim=1) + eps
return ((tp / denom_tp).sum() + (fn / denom_fn).sum()) * y_pred.size(2) * y_pred.size(3) * 0.5
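For intuition, the soft F-beta above reduces to the familiar F1 when beta = 1. A NumPy sketch on a single binary occlusion map (hypothetical `fbeta` helper; plain sums replace the per-batch spatial reductions of the tensor version):

```python
import numpy as np

def fbeta(y_true, y_pred, beta=1.0, eps=1e-8):
    # Soft F-beta on binary maps: (1 + b^2) * P * R / (b^2 * P + R)
    beta2 = beta ** 2
    tp = (y_true * y_pred).sum()
    precision = tp / (y_pred.sum() + eps)
    recall = tp / (y_true.sum() + eps)
    return (1 + beta2) * precision * recall / (beta2 * precision + recall + eps)

y_true = np.array([1, 1, 0, 0], dtype=np.float32)
y_pred = np.array([1, 0, 1, 0], dtype=np.float32)
score = fbeta(y_true, y_pred)  # precision = recall = 0.5  ->  F1 = 0.5
```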
class MultiScaleEPE_FlowNet(nn.Module):
def __init__(self,
args):
super(MultiScaleEPE_FlowNet, self).__init__()
self._args = args
self._batch_size = args.batch_size
self._weights = [0.005, 0.01, 0.02, 0.08, 0.32]
def forward(self, output_dict, target_dict):
loss_dict = {}
if self.training:
outputs = [output_dict[key] for key in ["flow2", "flow3", "flow4", "flow5", "flow6"]]
# div_flow trick
target = self._args.model_div_flow * target_dict["target1"]
total_loss = 0
for i, output_i in enumerate(outputs):
target_i = _downsample2d_as(target, output_i)
epe_i = _elementwise_epe(output_i, target_i)
total_loss = total_loss + self._weights[i] * epe_i.sum() / self._batch_size
loss_dict["epe%i" % (i + 2)] = epe_i.mean()
loss_dict["total_loss"] = total_loss
else:
output = output_dict["flow1"]
target = target_dict["target1"]
epe = _elementwise_epe(output, target)
loss_dict["epe"] = epe.mean()
return loss_dict
class MultiScaleEPE_FlowNet_IRR(nn.Module):
def __init__(self,
args):
super(MultiScaleEPE_FlowNet_IRR, self).__init__()
self._args = args
self._batch_size = args.batch_size
self._weights = [0.005, 0.01, 0.02, 0.08, 0.32]
self._num_iters = args.num_iters
def forward(self, output_dict, target_dict):
loss_dict = {}
if self.training:
outputs_flo = [output_dict[key] for key in ["flow2", "flow3", "flow4", "flow5", "flow6"]]
# div_flow trick
target_f = self._args.model_div_flow * target_dict["target1"]
total_loss = 0
for ii, output_ii in enumerate(outputs_flo):
target_f_ii = _downsample2d_as(target_f, output_ii[0])
for jj, output_ii_jj in enumerate(output_ii):
epe_f_ii = _elementwise_epe(output_ii_jj, target_f_ii)
total_loss = total_loss + self._weights[ii] * epe_f_ii.sum()
loss_dict["epe%i" % (ii + 2)] = epe_f_ii.mean()
loss_dict["total_loss"] = total_loss / self._batch_size / self._num_iters
else:
output = output_dict["flow1"]
target_f = target_dict["target1"]
            epe_f = _elementwise_epe(output, target_f)
loss_dict["epe"] = epe_f.mean()
return loss_dict

class MultiScaleEPE_FlowNet_IRR_Bi(nn.Module):
    def __init__(self, args):
super(MultiScaleEPE_FlowNet_IRR_Bi, self).__init__()
self._args = args
self._batch_size = args.batch_size
self._weights = [0.005, 0.01, 0.02, 0.08, 0.32]
self._num_iters = args.num_iters
def forward(self, output_dict, target_dict):
loss_dict = {}
if self.training:
outputs_flo = [output_dict[key] for key in ["flow2", "flow3", "flow4", "flow5", "flow6"]]
# div_flow trick
target_f = self._args.model_div_flow * target_dict["target1"]
target_b = self._args.model_div_flow * target_dict["target2"]
total_loss = 0
for ii, output_ii in enumerate(outputs_flo):
target_f_ii = _downsample2d_as(target_f, output_ii[0][0])
target_b_ii = _downsample2d_as(target_b, output_ii[0][1])
for jj, output_ii_jj in enumerate(output_ii):
epe_f_ii = _elementwise_epe(output_ii_jj[0], target_f_ii)
epe_b_ii = _elementwise_epe(output_ii_jj[1], target_b_ii)
total_loss = total_loss + self._weights[ii] * (epe_f_ii.sum() + epe_b_ii.sum())
loss_dict["epe%i" % (ii + 2)] = (epe_f_ii.mean() + epe_b_ii.mean()) / 2
loss_dict["total_loss"] = total_loss / self._batch_size / self._num_iters / 2
else:
epe_f = _elementwise_epe(output_dict["flow1"], target_dict["target1"])
loss_dict["epe"] = epe_f.mean()
return loss_dict

class MultiScaleEPE_FlowNet_IRR_Occ(nn.Module):
    def __init__(self, args):
super(MultiScaleEPE_FlowNet_IRR_Occ, self).__init__()
self._args = args
self._batch_size = args.batch_size
self._weights = [0.005, 0.01, 0.02, 0.08, 0.32]
self._num_iters = args.num_iters
self.f1_score_bal_loss = f1_score_bal_loss
self.occ_activ = nn.Sigmoid()
def forward(self, output_dict, target_dict):
loss_dict = {}
if self.training:
outputs_flo = [output_dict[key] for key in ["flow2", "flow3", "flow4", "flow5", "flow6"]]
outputs_occ = [output_dict[key] for key in ["occ2", "occ3", "occ4", "occ5", "occ6"]]
# div_flow trick
target = self._args.model_div_flow * target_dict["target1"]
target_occ = target_dict["target_occ1"]
flow_loss = 0
occ_loss = 0
for ii, output_ii in enumerate(outputs_flo):
target_ii = _downsample2d_as(target, output_ii[0])
for jj, output_ii_jj in enumerate(output_ii):
flow_loss = flow_loss + self._weights[ii] * _elementwise_epe(output_ii_jj, target_ii).sum()
for ii, output_ii in enumerate(outputs_occ):
target_occ_f = _downsample2d_as(target_occ, output_ii[0])
for jj, output_ii_jj in enumerate(output_ii):
occ_loss = occ_loss + self._weights[ii] * self.f1_score_bal_loss(self.occ_activ(output_ii_jj), target_occ_f)
f_loss = flow_loss.detach()
o_loss = occ_loss.detach()
if f_loss > o_loss:
f_l_w = 1
o_l_w = f_loss / o_loss
else:
f_l_w = o_loss / f_loss
o_l_w = 1
loss_dict["flow_loss"] = flow_loss / self._batch_size / self._num_iters
loss_dict["occ_loss"] = occ_loss / self._batch_size / self._num_iters
loss_dict["total_loss"] = (flow_loss * f_l_w + occ_loss * o_l_w) / self._batch_size / self._num_iters
else:
loss_dict["epe"] = _elementwise_epe(output_dict["flow1"], target_dict["target1"]).mean()
loss_dict["F1"] = f1_score(target_dict["target_occ1"], torch.round(self.occ_activ(output_dict["occ1"])))
return loss_dict
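# The flow/occlusion weighting above equalizes the two detached loss
# magnitudes by scaling the smaller one up to the larger. A standalone sketch
# with made-up scalars (plain floats standing in for detached tensors):

```python
def balance_weights(f_loss, o_loss):
    # Mirror of the if/else in forward(): the larger loss keeps weight 1,
    # the smaller loss is scaled up by the ratio of the two magnitudes.
    if f_loss > o_loss:
        return 1.0, f_loss / o_loss
    return o_loss / f_loss, 1.0

f_l_w, o_l_w = balance_weights(12.0, 3.0)
# After weighting, both terms contribute equally: 1.0 * 12.0 == 4.0 * 3.0
assert f_l_w * 12.0 == o_l_w * 3.0
```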

class MultiScaleEPE_FlowNet_IRR_Bi_Occ(nn.Module):
    def __init__(self, args):
super(MultiScaleEPE_FlowNet_IRR_Bi_Occ, self).__init__()
self._args = args
self._batch_size = args.batch_size
self._weights = [0.005, 0.01, 0.02, 0.08, 0.32]
self._num_iters = args.num_iters
self.f1_score_bal_loss = f1_score_bal_loss
self.occ_activ = nn.Sigmoid()
def forward(self, output_dict, target_dict):
loss_dict = {}
if self.training:
outputs_flo = [output_dict[key] for key in ["flow2", "flow3", "flow4", "flow5", "flow6"]]
outputs_occ = [output_dict[key] for key in ["occ2", "occ3", "occ4", "occ5", "occ6"]]
# div_flow trick
target_f = self._args.model_div_flow * target_dict["target1"]
target_b = self._args.model_div_flow * target_dict["target2"]
target_occ_f = target_dict["target_occ1"]
target_occ_b = target_dict["target_occ2"]
flow_loss = 0
occ_loss = 0
for ii, output_ii in enumerate(outputs_flo):
target_f_ii = _downsample2d_as(target_f, output_ii[0][0])
target_b_ii = _downsample2d_as(target_b, output_ii[0][1])
for jj, output_ii_jj in enumerate(output_ii):
epe_f_ii = _elementwise_epe(output_ii_jj[0], target_f_ii)
epe_b_ii = _elementwise_epe(output_ii_jj[1], target_b_ii)
flow_loss = flow_loss + self._weights[ii] * (epe_f_ii.sum() + epe_b_ii.sum()) * 0.5
for ii, output_ii in enumerate(outputs_occ):
                # Downsample from the full-resolution occ targets at every scale
                # (avoid progressively re-downsampling the previous scale's target).
                target_occ_f_ii = _downsample2d_as(target_occ_f, output_ii[0][0])
                target_occ_b_ii = _downsample2d_as(target_occ_b, output_ii[0][1])
                for jj, output_ii_jj in enumerate(output_ii):
                    output_occ_f = self.occ_activ(output_ii_jj[0])
                    output_occ_b = self.occ_activ(output_ii_jj[1])
                    bce_f_ii = self.f1_score_bal_loss(output_occ_f, target_occ_f_ii)
                    bce_b_ii = self.f1_score_bal_loss(output_occ_b, target_occ_b_ii)
occ_loss = occ_loss + self._weights[ii] * (bce_f_ii + bce_b_ii) * 0.5
f_loss = flow_loss.detach()
o_loss = occ_loss.detach()
if f_loss > o_loss:
f_l_w = 1
o_l_w = f_loss / o_loss
else:
f_l_w = o_loss / f_loss
o_l_w = 1
loss_dict["flow_loss"] = flow_loss / self._batch_size / self._num_iters
loss_dict["occ_loss"] = occ_loss / self._batch_size / self._num_iters
loss_dict["total_loss"] = (flow_loss * f_l_w + occ_loss * o_l_w) / self._batch_size / self._num_iters
else:
loss_dict["epe"] = _elementwise_epe(output_dict["flow1"], target_dict["target1"]).mean()
loss_dict["F1"] = f1_score(target_dict["target_occ1"], torch.round(self.occ_activ(output_dict["occ1"])))
return loss_dict

class MultiScaleEPE_FlowNet_IRR_Bi_Occ_upsample(nn.Module):
    def __init__(self, args):
super(MultiScaleEPE_FlowNet_IRR_Bi_Occ_upsample, self).__init__()
self._args = args
self._batch_size = args.batch_size
self._weights = [0.0003125, 0.00125, 0.005, 0.01, 0.02, 0.08, 0.32]
self.occ_activ = nn.Sigmoid()
self.f1_score_bal_loss = f1_score_bal_loss
def forward(self, output_dict, target_dict):
loss_dict = {}
if self.training:
outputs_flo = [output_dict[key] for key in ["flow", "flow1", "flow2",
gitextract_pwdsdpvy/
├── .gitignore
├── LICENSE
├── README.md
├── __init__.py
├── augmentations.py
├── commandline.py
├── configuration.py
├── datasets/
│ ├── __init__.py
│ ├── common.py
│ ├── flyingThings3D.py
│ ├── flyingchairs.py
│ ├── flyingchairsOcc.py
│ ├── kitti_combined.py
│ ├── sintel.py
│ └── transforms.py
├── flyingchairsocc/
│ └── README.md
├── install.sh
├── logger.py
├── losses.py
├── main.py
├── models/
│ ├── IRR_FlowNet.py
│ ├── IRR_PWC.py
│ ├── __init__.py
│ ├── correlation_package/
│ │ ├── __init__.py
│ │ ├── correlation.py
│ │ ├── correlation_cuda.cc
│ │ ├── correlation_cuda_kernel.cu
│ │ ├── correlation_cuda_kernel.cuh
│ │ └── setup.py
│ ├── correlation_package_cu9/
│ │ ├── __init__.py
│ │ ├── correlation.py
│ │ ├── correlation_cuda.cc
│ │ ├── correlation_cuda_kernel.cu
│ │ ├── correlation_cuda_kernel.cuh
│ │ └── setup.py
│ ├── flownet1s.py
│ ├── flownet1s_irr.py
│ ├── flownet1s_irr_bi.py
│ ├── flownet1s_irr_occ.py
│ ├── flownet1s_irr_occ_bi.py
│ ├── flownet_modules.py
│ ├── irr_modules.py
│ ├── pwc_modules.py
│ ├── pwcnet.py
│ ├── pwcnet_bi.py
│ ├── pwcnet_irr.py
│ ├── pwcnet_irr_bi.py
│ ├── pwcnet_irr_occ.py
│ ├── pwcnet_irr_occ_bi.py
│ ├── pwcnet_occ.py
│ └── pwcnet_occ_bi.py
├── optim/
│ └── __init__.py
├── runtime.py
├── saved_check_point/
│ └── pwcnet/
│ ├── IRR-PWC_flyingchairsOcc/
│ │ ├── checkpoint_best.ckpt
│ │ └── checkpoint_latest.ckpt
│ ├── IRR-PWC_kitti/
│ │ ├── checkpoint_best.ckpt
│ │ └── checkpoint_latest.ckpt
│ ├── IRR-PWC_sintel/
│ │ ├── checkpoint_best.ckpt
│ │ └── checkpoint_latest.ckpt
│ ├── IRR-PWC_things3d/
│ │ ├── checkpoint_best.ckpt
│ │ └── checkpoint_latest.ckpt
│ ├── PWCNet/
│ │ └── checkpoint_best.ckpt
│ └── PWCNet-irr/
│ └── checkpoint_best.ckpt
├── scripts/
│ ├── IRR-FlowNet_flyingChairsOcc.sh
│ ├── IRR-PWC_flyingChairsOcc.sh
│ ├── IRR-PWC_kitti_train.sh
│ ├── IRR-PWC_kitti_train_full.sh
│ ├── IRR-PWC_sintel_train.sh
│ ├── IRR-PWC_sintel_train_full.sh
│ ├── IRR-PWC_things3d.sh
│ ├── flownet1s.sh
│ ├── flownet1s_irr1.sh
│ ├── flownet1s_irr2.sh
│ ├── pwcnet.sh
│ ├── pwcnet_irr.sh
│ └── validation/
│ ├── IRR-FlowNet_flyingChairs.sh
│ ├── IRR-PWC_flyingChairs.sh
│ ├── IRR-PWC_kitti.sh
│ ├── IRR-PWC_sintel.sh
│ ├── IRR-PWC_things3d.sh
│ ├── flownet1s.sh
│ ├── flownet1s_irr1.sh
│ ├── flownet1s_irr2.sh
│ ├── pwcnet.sh
│ └── pwcnet_irr.sh
├── tools.py
└── utils/
├── __init__.py
├── flow.py
└── interpolation.py
SYMBOL INDEX (472 symbols across 39 files)
FILE: augmentations.py
function denormalize_coords (line 12) | def denormalize_coords(xx, yy, width, height):
function normalize_coords (line 19) | def normalize_coords(xx, yy, width, height):
function apply_transform_to_params (line 26) | def apply_transform_to_params(theta0, theta_transform):
class _IdentityParams (line 52) | class _IdentityParams(nn.Module):
method __init__ (line 53) | def __init__(self):
method _update (line 59) | def _update(self, batch_size):
method forward (line 64) | def forward(self, batch_size):
class RandomMirror (line 71) | class RandomMirror(nn.Module):
method __init__ (line 72) | def __init__(self, vertical=True, p=0.5):
method update_probs (line 79) | def update_probs(self, batch_size):
method forward (line 83) | def forward(self, theta1, theta2):
class RandomCrop (line 106) | class RandomCrop(nn.Module):
method __init__ (line 112) | def __init__(self, crop):
method forward (line 118) | def forward(self, im1, im2, flo):
class RandomAffineFlow (line 135) | class RandomAffineFlow(nn.Module):
method __init__ (line 136) | def __init__(self, args, addnoise=True):
method inverse_transform_coords (line 150) | def inverse_transform_coords(self, width, height, thetas, offset_x=Non...
method transform_coords (line 174) | def transform_coords(self, width, height, thetas):
method find_invalid (line 202) | def find_invalid(self, width, height, thetas):
method apply_random_transforms_to_params (line 231) | def apply_random_transforms_to_params(self,
method transform_image (line 282) | def transform_image(self, images, thetas):
method transform_flow (line 288) | def transform_flow(self, flow, theta1, theta2):
method forward (line 312) | def forward(self, example_dict):
class RandomAffineFlowOcc (line 368) | class RandomAffineFlowOcc(nn.Module):
method __init__ (line 369) | def __init__(self, args, addnoise=True, crop=None):
method inverse_transform_coords (line 387) | def inverse_transform_coords(self, width, height, thetas, offset_x=Non...
method transform_coords (line 411) | def transform_coords(self, width, height, thetas):
method find_invalid (line 439) | def find_invalid(self, width, height, thetas):
method apply_random_transforms_to_params (line 468) | def apply_random_transforms_to_params(self,
method transform_image (line 519) | def transform_image(self, images, thetas):
method transform_flow (line 525) | def transform_flow(self, flow, theta1, theta2):
method check_out_of_bound (line 549) | def check_out_of_bound(self, flow, occ, batch_size):
method random_crop (line 564) | def random_crop(self, im1, im2, flo_f, flo_b, occ1, occ2):
method forward (line 586) | def forward(self, example_dict):
class RandomAffineFlowOccSintel (line 656) | class RandomAffineFlowOccSintel(nn.Module):
method __init__ (line 657) | def __init__(self, args, addnoise=True, crop=None):
method inverse_transform_coords (line 675) | def inverse_transform_coords(self, width, height, thetas, offset_x=Non...
method transform_coords (line 699) | def transform_coords(self, width, height, thetas):
method find_invalid (line 727) | def find_invalid(self, width, height, thetas):
method apply_random_transforms_to_params (line 756) | def apply_random_transforms_to_params(self,
method transform_image (line 807) | def transform_image(self, images, thetas):
method transform_flow (line 813) | def transform_flow(self, flow, theta1, theta2):
method check_out_of_bound (line 837) | def check_out_of_bound(self, flow, occ, batch_size):
method random_crop (line 852) | def random_crop(self, im1, im2, flo_f, occ1):
method forward (line 872) | def forward(self, example_dict):
class RandomAffineFlowOccKITTI (line 935) | class RandomAffineFlowOccKITTI(nn.Module):
method __init__ (line 936) | def __init__(self, args, addnoise=True, crop=None):
method inverse_transform_coords (line 954) | def inverse_transform_coords(self, width, height, thetas, offset_x=Non...
method transform_coords (line 978) | def transform_coords(self, width, height, thetas):
method find_invalid (line 1006) | def find_invalid(self, width, height, thetas):
method apply_random_transforms_to_params (line 1035) | def apply_random_transforms_to_params(self,
method transform_image (line 1086) | def transform_image(self, images, thetas):
method transform_flow (line 1092) | def transform_flow(self, flow, theta1, theta2, valid_mask):
method check_out_of_bound (line 1117) | def check_out_of_bound(self, flow, occ, batch_size):
method random_crop (line 1132) | def random_crop(self, im1, im2, flo_f, valid_mask):
method forward (line 1152) | def forward(self, example_dict):
FILE: commandline.py
function _get_type_from_arg (line 22) | def _get_type_from_arg(arg):
function _add_arguments_for_module (line 29) | def _add_arguments_for_module(parser,
function _add_special_arguments (line 138) | def _add_special_arguments(parser):
function _parse_arguments (line 179) | def _parse_arguments():
function postprocess_args (line 341) | def postprocess_args(args):
function setup_logging_and_parse_arguments (line 383) | def setup_logging_and_parse_arguments(blocktitle):
FILE: configuration.py
class ModelAndLoss (line 20) | class ModelAndLoss(nn.Module):
method __init__ (line 21) | def __init__(self, args, model, training_loss, evaluation_loss=None):
method training_loss (line 28) | def training_loss(self):
method evaluation_loss (line 32) | def evaluation_loss(self):
method model (line 36) | def model(self):
method num_parameters (line 39) | def num_parameters(self):
method forward (line 45) | def forward(self, example_dict):
function configure_runtime_augmentations (line 65) | def configure_runtime_augmentations(args):
function configure_model_and_loss (line 108) | def configure_model_and_loss(args):
function configure_random_seed (line 169) | def configure_random_seed(args):
class CheckpointSaver (line 192) | class CheckpointSaver:
method __init__ (line 193) | def __init__(self,
method _load_state_dict_into_module (line 211) | def _load_state_dict_into_module(self, state_dict, module, strict=True):
method restore (line 235) | def restore(self, filename, model_and_loss, include_params="*", exclud...
method restore_latest (line 271) | def restore_latest(self, directory, model_and_loss, include_params="*"...
method restore_best (line 276) | def restore_best(self, directory, model_and_loss, include_params="*", ...
method save_latest (line 281) | def save_latest(self, directory, model_and_loss, stats_dict, store_as_...
function configure_checkpoint_saver (line 317) | def configure_checkpoint_saver(args, model_and_loss):
function configure_data_loaders (line 362) | def configure_data_loaders(args):
function _print_trainable_params (line 456) | def _print_trainable_params(model_and_loss, match="*"):
function _generate_trainable_params (line 468) | def _generate_trainable_params(model_and_loss, match="*"):
function _param_names_and_trainable_generator (line 475) | def _param_names_and_trainable_generator(model_and_loss, match="*"):
function configure_optimizer (line 488) | def configure_optimizer(args, model_and_loss):
function configure_lr_scheduler (line 579) | def configure_lr_scheduler(args, optimizer):
FILE: datasets/common.py
function numpy2torch (line 10) | def numpy2torch(array):
function read_flo_as_float32 (line 19) | def read_flo_as_float32(filename):
function read_occ_image_as_float32 (line 30) | def read_occ_image_as_float32(filename):
function read_image_as_float32 (line 37) | def read_image_as_float32(filename):
function read_image_as_byte (line 41) | def read_image_as_byte(filename):
FILE: datasets/flyingThings3D.py
function fillingInNaN (line 15) | def fillingInNaN(flow):
class FlyingThings3d (line 41) | class FlyingThings3d(data.Dataset):
method __init__ (line 42) | def __init__(self,
method __getitem__ (line 135) | def __getitem__(self, index):
method __len__ (line 187) | def __len__(self):
class FlyingThings3dFinalTrain (line 191) | class FlyingThings3dFinalTrain(FlyingThings3d):
method __init__ (line 192) | def __init__(self,
class FlyingThings3dFinalTest (line 207) | class FlyingThings3dFinalTest(FlyingThings3d):
method __init__ (line 208) | def __init__(self,
class FlyingThings3dCleanTrain (line 223) | class FlyingThings3dCleanTrain(FlyingThings3d):
method __init__ (line 224) | def __init__(self,
class FlyingThings3dCleanTest (line 239) | class FlyingThings3dCleanTest(FlyingThings3d):
method __init__ (line 240) | def __init__(self,
FILE: datasets/flyingchairs.py
class FlyingChairs (line 81) | class FlyingChairs(data.Dataset):
method __init__ (line 82) | def __init__(self,
method __getitem__ (line 150) | def __getitem__(self, index):
method __len__ (line 185) | def __len__(self):
class FlyingChairsTrain (line 189) | class FlyingChairsTrain(FlyingChairs):
method __init__ (line 190) | def __init__(self,
class FlyingChairsValid (line 201) | class FlyingChairsValid(FlyingChairs):
method __init__ (line 202) | def __init__(self,
class FlyingChairsFull (line 213) | class FlyingChairsFull(FlyingChairs):
method __init__ (line 214) | def __init__(self,
FILE: datasets/flyingchairsOcc.py
class FlyingChairsOcc (line 81) | class FlyingChairsOcc(data.Dataset):
method __init__ (line 82) | def __init__(self,
method __getitem__ (line 165) | def __getitem__(self, index):
method __len__ (line 210) | def __len__(self):
class FlyingChairsOccTrain (line 214) | class FlyingChairsOccTrain(FlyingChairsOcc):
method __init__ (line 215) | def __init__(self,
class FlyingChairsOccValid (line 226) | class FlyingChairsOccValid(FlyingChairsOcc):
method __init__ (line 227) | def __init__(self,
class FlyingChairsOccFull (line 238) | class FlyingChairsOccFull(FlyingChairsOcc):
method __init__ (line 239) | def __init__(self,
FILE: datasets/kitti_combined.py
function read_png_flow (line 19) | def read_png_flow(flow_file):
function kitti_random_crop (line 37) | def kitti_random_crop(im1, im2, flo_f, valid_mask, crop_height=370, crop...
class Kitti_comb_test (line 55) | class Kitti_comb_test(data.Dataset):
method __init__ (line 56) | def __init__(self,
method __getitem__ (line 156) | def __getitem__(self, index):
method __len__ (line 181) | def __len__(self):
class Kitti_comb (line 185) | class Kitti_comb(data.Dataset):
method __init__ (line 186) | def __init__(self,
method __getitem__ (line 337) | def __getitem__(self, index):
method __len__ (line 374) | def __len__(self):
class KittiCombTrain (line 378) | class KittiCombTrain(Kitti_comb):
method __init__ (line 379) | def __init__(self,
class KittiCombVal (line 399) | class KittiCombVal(Kitti_comb):
method __init__ (line 400) | def __init__(self,
class KittiCombFull (line 420) | class KittiCombFull(Kitti_comb):
method __init__ (line 421) | def __init__(self,
class KittiComb2015Train (line 441) | class KittiComb2015Train(Kitti_comb):
method __init__ (line 442) | def __init__(self,
class KittiComb2015Val (line 458) | class KittiComb2015Val(Kitti_comb):
method __init__ (line 459) | def __init__(self,
class KittiComb2015Full (line 475) | class KittiComb2015Full(Kitti_comb):
method __init__ (line 476) | def __init__(self,
class KittiComb2015Test (line 492) | class KittiComb2015Test(Kitti_comb_test):
method __init__ (line 493) | def __init__(self,
class KittiComb2012Train (line 506) | class KittiComb2012Train(Kitti_comb):
method __init__ (line 507) | def __init__(self,
class KittiComb2012Val (line 523) | class KittiComb2012Val(Kitti_comb):
method __init__ (line 524) | def __init__(self,
class KittiComb2012Full (line 540) | class KittiComb2012Full(Kitti_comb):
method __init__ (line 541) | def __init__(self,
class KittiComb2012Test (line 557) | class KittiComb2012Test(Kitti_comb_test):
method __init__ (line 558) | def __init__(self,
FILE: datasets/sintel.py
class _Sintel (line 30) | class _Sintel(data.Dataset):
method __init__ (line 31) | def __init__(self,
method __getitem__ (line 168) | def __getitem__(self, index):
method __len__ (line 205) | def __len__(self):
class _Sintel_test (line 209) | class _Sintel_test(data.Dataset):
method __init__ (line 210) | def __init__(self,
method __getitem__ (line 285) | def __getitem__(self, index):
method __len__ (line 314) | def __len__(self):
class SintelTrainingCleanTrain (line 318) | class SintelTrainingCleanTrain(_Sintel):
method __init__ (line 319) | def __init__(self, args, root, photometric_augmentations=True):
class SintelTrainingCleanValid (line 329) | class SintelTrainingCleanValid(_Sintel):
method __init__ (line 330) | def __init__(self, args, root, photometric_augmentations=False):
class SintelTrainingCleanFull (line 340) | class SintelTrainingCleanFull(_Sintel):
method __init__ (line 341) | def __init__(self, args, root, photometric_augmentations=True):
class SintelTrainingFinalTrain (line 351) | class SintelTrainingFinalTrain(_Sintel):
method __init__ (line 352) | def __init__(self, args, root, photometric_augmentations=True):
class SintelTrainingFinalValid (line 362) | class SintelTrainingFinalValid(_Sintel):
method __init__ (line 363) | def __init__(self, args, root, photometric_augmentations=False):
class SintelTrainingFinalFull (line 373) | class SintelTrainingFinalFull(_Sintel):
method __init__ (line 374) | def __init__(self, args, root, photometric_augmentations=True):
class SintelTrainingCombTrain (line 384) | class SintelTrainingCombTrain(_Sintel):
method __init__ (line 385) | def __init__(self, args, root, photometric_augmentations=True):
class SintelTrainingCombValid (line 395) | class SintelTrainingCombValid(_Sintel):
method __init__ (line 396) | def __init__(self, args, root, photometric_augmentations=False):
class SintelTrainingCombFull (line 406) | class SintelTrainingCombFull(_Sintel):
method __init__ (line 407) | def __init__(self, args, root, photometric_augmentations=True):
class SintelTestClean (line 417) | class SintelTestClean(_Sintel_test):
method __init__ (line 418) | def __init__(self, args, root, photometric_augmentations=False):
class SintelTestFinal (line 427) | class SintelTestFinal(_Sintel_test):
method __init__ (line 428) | def __init__(self, args, root, photometric_augmentations=False):
FILE: datasets/transforms.py
function image_random_gamma (line 9) | def image_random_gamma(image, min_gamma=0.7, max_gamma=1.5, clip_image=F...
class RandomGamma (line 17) | class RandomGamma:
method __init__ (line 18) | def __init__(self, min_gamma=0.7, max_gamma=1.5, clip_image=False):
method __call__ (line 23) | def __call__(self, image):
class TransformChainer (line 35) | class TransformChainer:
method __init__ (line 36) | def __init__(self, list_of_transforms):
method __call__ (line 39) | def __call__(self, *args):
class ConcatTransformSplitChainer (line 53) | class ConcatTransformSplitChainer:
method __init__ (line 54) | def __init__(self, list_of_transforms, from_numpy=True, to_numpy=False):
method __call__ (line 59) | def __call__(self, *args):
FILE: logger.py
function get_default_logging_format (line 13) | def get_default_logging_format(colorize=False, brackets=False):
function get_default_logging_datefmt (line 25) | def get_default_logging_datefmt():
function log_module_info (line 29) | def log_module_info(module):
class LogbookFormatter (line 35) | class LogbookFormatter(logging.Formatter):
method __init__ (line 36) | def __init__(self, fmt=None, datefmt=None):
method remove_colors_from_msg (line 40) | def remove_colors_from_msg(self, msg):
method format (line 44) | def format(self, record=None):
class ConsoleFormatter (line 49) | class ConsoleFormatter(logging.Formatter):
method __init__ (line 50) | def __init__(self, fmt=None, datefmt=None):
method format (line 53) | def format(self, record=None):
class SkipLogbookFilter (line 59) | class SkipLogbookFilter(logging.Filter):
method filter (line 60) | def filter(self, record):
function configure_logging (line 64) | def configure_logging(filename=None):
class LoggingBlock (line 108) | class LoggingBlock:
method __init__ (line 109) | def __init__(self, title, emph=False):
method __enter__ (line 119) | def __enter__(self):
method __exit__ (line 123) | def __exit__(self, exc_type, exc_value, traceback):
FILE: losses.py
function _elementwise_epe (line 8) | def _elementwise_epe(input_flow, target_flow):
function _elementwise_robust_epe_char (line 12) | def _elementwise_robust_epe_char(input_flow, target_flow):
function _downsample2d_as (line 16) | def _downsample2d_as(inputs, target_as):
function _upsample2d_as (line 20) | def _upsample2d_as(inputs, target_as, mode="bilinear"):
function f1_score (line 24) | def f1_score(y_true, y_pred):
function fbeta_score (line 27) | def fbeta_score(y_true, y_pred, beta, eps=1e-8):
function f1_score_bal_loss (line 39) | def f1_score_bal_loss(y_pred, y_true):
class MultiScaleEPE_FlowNet (line 51) | class MultiScaleEPE_FlowNet(nn.Module):
method __init__ (line 52) | def __init__(self,
method forward (line 60) | def forward(self, output_dict, target_dict):
class MultiScaleEPE_FlowNet_IRR (line 84) | class MultiScaleEPE_FlowNet_IRR(nn.Module):
method __init__ (line 85) | def __init__(self,
method forward (line 94) | def forward(self, output_dict, target_dict):
class MultiScaleEPE_FlowNet_IRR_Bi (line 120) | class MultiScaleEPE_FlowNet_IRR_Bi(nn.Module):
method __init__ (line 121) | def __init__(self,
method forward (line 130) | def forward(self, output_dict, target_dict):
class MultiScaleEPE_FlowNet_IRR_Occ (line 157) | class MultiScaleEPE_FlowNet_IRR_Occ(nn.Module):
method __init__ (line 158) | def __init__(self,
method forward (line 170) | def forward(self, output_dict, target_dict):
class MultiScaleEPE_FlowNet_IRR_Bi_Occ (line 213) | class MultiScaleEPE_FlowNet_IRR_Bi_Occ(nn.Module):
method __init__ (line 214) | def __init__(self,
method forward (line 226) | def forward(self, output_dict, target_dict):
class MultiScaleEPE_FlowNet_IRR_Bi_Occ_upsample (line 278) | class MultiScaleEPE_FlowNet_IRR_Bi_Occ_upsample(nn.Module):
method __init__ (line 279) | def __init__(self,
method forward (line 289) | def forward(self, output_dict, target_dict):
class MultiScaleEPE_PWC (line 344) | class MultiScaleEPE_PWC(nn.Module):
method __init__ (line 345) | def __init__(self,
method forward (line 353) | def forward(self, output_dict, target_dict):
class MultiScaleEPE_PWC_Bi (line 374) | class MultiScaleEPE_PWC_Bi(nn.Module):
method __init__ (line 375) | def __init__(self,
method forward (line 383) | def forward(self, output_dict, target_dict):
class MultiScaleEPE_PWC_Occ (line 405) | class MultiScaleEPE_PWC_Occ(nn.Module):
method __init__ (line 406) | def __init__(self,
method forward (line 417) | def forward(self, output_dict, target_dict):
class MultiScaleEPE_PWC_Bi_Occ (line 457) | class MultiScaleEPE_PWC_Bi_Occ(nn.Module):
method __init__ (line 458) | def __init__(self,
method forward (line 469) | def forward(self, output_dict, target_dict):
class MultiScaleEPE_PWC_Bi_Occ_upsample (line 515) | class MultiScaleEPE_PWC_Bi_Occ_upsample(nn.Module):
method __init__ (line 516) | def __init__(self,
method forward (line 527) | def forward(self, output_dict, target_dict):
class MultiScaleEPE_PWC_Bi_Occ_upsample_Sintel (line 579) | class MultiScaleEPE_PWC_Bi_Occ_upsample_Sintel(nn.Module):
method __init__ (line 580) | def __init__(self,
method forward (line 591) | def forward(self, output_dict, target_dict):
class MultiScaleEPE_PWC_Bi_Occ_upsample_KITTI (line 640) | class MultiScaleEPE_PWC_Bi_Occ_upsample_KITTI(nn.Module):
method __init__ (line 641) | def __init__(self,
method forward (line 651) | def forward(self, output_dict, target_dict):
FILE: main.py
function main (line 14) | def main():
FILE: models/IRR_FlowNet.py
class FlowNetS (line 11) | class FlowNetS(nn.Module):
method __init__ (line 12) | def __init__(self, args):
method forward (line 72) | def forward(self, conv2_im1, conv3_im1, conv3_im2):
class FlowNet1S (line 130) | class FlowNet1S(nn.Module):
method __init__ (line 131) | def __init__(self, args, div_flow=0.05):
method forward (line 153) | def forward(self, input_dict):
FILE: models/IRR_PWC.py
class PWCNet (line 14) | class PWCNet(nn.Module):
method __init__ (line 15) | def __init__(self, args, div_flow=0.05):
method forward (line 51) | def forward(self, input_dict):
FILE: models/correlation_package/correlation.py
class CorrelationFunction (line 6) | class CorrelationFunction(Function):
method __init__ (line 8) | def __init__(self, pad_size=3, kernel_size=3, max_displacement=20, str...
method forward (line 18) | def forward(self, input1, input2):
method backward (line 31) | def backward(self, grad_output):
class Correlation (line 47) | class Correlation(Module):
method __init__ (line 48) | def __init__(self, pad_size=0, kernel_size=0, max_displacement=0, stri...
method forward (line 57) | def forward(self, input1, input2):
FILE: models/correlation_package/correlation_cuda.cc
function correlation_forward_cuda (line 8) | int correlation_forward_cuda(at::Tensor& input1, at::Tensor& input2, at:...
function correlation_backward_cuda (line 86) | int correlation_backward_cuda(at::Tensor& input1, at::Tensor& input2, at...
function PYBIND11_MODULE (line 165) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: models/correlation_package_cu9/correlation.py
class CorrelationFunction (line 6) | class CorrelationFunction(Function):
method __init__ (line 8) | def __init__(self, pad_size=3, kernel_size=3, max_displacement=20, str...
method forward (line 18) | def forward(self, input1, input2):
method backward (line 31) | def backward(self, grad_output):
class Correlation (line 47) | class Correlation(Module):
method __init__ (line 48) | def __init__(self, pad_size=0, kernel_size=0, max_displacement=0, stri...
method forward (line 57) | def forward(self, input1, input2):
FILE: models/correlation_package_cu9/correlation_cuda.cc
function correlation_forward_cuda (line 10) | int correlation_forward_cuda(at::Tensor& input1, at::Tensor& input2, at:...
function correlation_backward_cuda (line 89) | int correlation_backward_cuda(at::Tensor& input1, at::Tensor& input2, at...
function PYBIND11_MODULE (line 169) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: models/flownet1s.py
class FlowNetS (line 10) | class FlowNetS(nn.Module):
method __init__ (line 11) | def __init__(self, args):
method forward (line 60) | def forward(self, inputs):
class FlowNet1S (line 96) | class FlowNet1S(nn.Module):
method __init__ (line 97) | def __init__(self, args, div_flow=0.05):
method forward (line 102) | def forward(self, input_dict):
FILE: models/flownet1s_irr.py
class FlowNetS (line 10) | class FlowNetS(nn.Module):
method __init__ (line 11) | def __init__(self, args):
method forward (line 55) | def forward(self, conv2_im1, conv3_im1, conv3_im2):
class FlowNet1S (line 89) | class FlowNet1S(nn.Module):
method __init__ (line 90) | def __init__(self, args, div_flow=0.05):
method forward (line 108) | def forward(self, input_dict):
FILE: models/flownet1s_irr_bi.py
class FlowNetS (line 10) | class FlowNetS(nn.Module):
method __init__ (line 11) | def __init__(self, args):
method forward (line 55) | def forward(self, conv2_im1, conv3_im1, conv3_im2):
class FlowNet1S (line 90) | class FlowNet1S(nn.Module):
method __init__ (line 91) | def __init__(self, args, div_flow=0.05):
method forward (line 109) | def forward(self, input_dict):
FILE: models/flownet1s_irr_occ.py
class FlowNetS (line 10) | class FlowNetS(nn.Module):
method __init__ (line 11) | def __init__(self, args):
method forward (line 71) | def forward(self, conv2_im1, conv3_im1, conv3_im2):
class FlowNet1S (line 129) | class FlowNet1S(nn.Module):
method __init__ (line 130) | def __init__(self, args, div_flow=0.05):
method forward (line 148) | def forward(self, input_dict):
FILE: models/flownet1s_irr_occ_bi.py
class FlowNetS (line 10) | class FlowNetS(nn.Module):
method __init__ (line 11) | def __init__(self, args):
method forward (line 71) | def forward(self, conv2_im1, conv3_im1, conv3_im2):
class FlowNet1S (line 129) | class FlowNet1S(nn.Module):
method __init__ (line 130) | def __init__(self, args, div_flow=0.05):
method forward (line 148) | def forward(self, input_dict):
FILE: models/flownet_modules.py
function conv (line 9) | def conv(in_planes, out_planes, kernel_size, stride, pad, nonlinear, bias):
function deconv (line 22) | def deconv(in_planes, out_planes, kernel_size, stride, pad, nonlinear, b...
function resize2D (line 35) | def resize2D(inputs, size_targets, mode="bilinear"):
function resize2D_as (line 47) | def resize2D_as(inputs, output_as, mode="bilinear"):
function concatenate_as (line 52) | def concatenate_as(tensor_list, tensor_as, dim, mode="bilinear"):
function upsample2d_as (line 57) | def upsample2d_as(inputs, target_as, mode="bilinear"):
function initialize_msra (line 62) | def initialize_msra(modules):
function get_grid (line 85) | def get_grid(x):
class WarpingLayer (line 93) | class WarpingLayer(nn.Module):
method __init__ (line 94) | def __init__(self):
method forward (line 97) | def forward(self, x, flow, height_im, width_im, div_flow):
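The `WarpingLayer` indexed above samples one feature map at positions displaced by a flow field. A minimal sketch of that operation, assuming standard `grid_sample`-based warping; the function name `warp` and the pixel-unit flow convention are illustrative, not the repository's exact interface (its `forward` additionally takes image size and `div_flow` scaling):

```python
import torch
import torch.nn.functional as F

def warp(x, flow):
    # Bilinearly sample feature map x [B, C, H, W] at positions displaced
    # by flow [B, 2, H, W] (pixel units; channel 0 = x-, channel 1 = y-offset).
    _, _, h, w = x.shape
    ys, xs = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    base = torch.stack((xs, ys), dim=0).float().to(x.device)  # [2, H, W]
    pos = base.unsqueeze(0) + flow                            # sample coords
    # Normalize to [-1, 1], the coordinate range grid_sample expects.
    gx = 2.0 * pos[:, 0] / max(w - 1, 1) - 1.0
    gy = 2.0 * pos[:, 1] / max(h - 1, 1) - 1.0
    grid = torch.stack((gx, gy), dim=-1)                      # [B, H, W, 2]
    return F.grid_sample(x, grid, mode="bilinear", align_corners=True)
```

With zero flow this reduces to the identity; out-of-range samples fall back to `grid_sample`'s zero padding.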
FILE: models/irr_modules.py
function conv (line 7) | def conv(in_planes, out_planes, kernel_size=3, stride=1, dilation=1, isR...
function upsample_factor2 (line 21) | def upsample_factor2(inputs, target_as):
class OccUpsampleNetwork (line 30) | class OccUpsampleNetwork(nn.Module):
method __init__ (line 31) | def __init__(self, ch_in, ch_out):
method forward (line 46) | def forward(self, occ, x):
function subtract_mean (line 59) | def subtract_mean(input):
class RefineFlow (line 63) | class RefineFlow(nn.Module):
method __init__ (line 64) | def __init__(self, ch_in):
method forward (line 85) | def forward(self, flow, diff_img, feature):
class RefineOcc (line 107) | class RefineOcc(nn.Module):
method __init__ (line 108) | def __init__(self, ch_in):
method forward (line 129) | def forward(self, occ, feat1, feat2):
FILE: models/pwc_modules.py
function conv (line 8) | def conv(in_planes, out_planes, kernel_size=3, stride=1, dilation=1, isR...
function initialize_msra (line 22) | def initialize_msra(modules):
function compute_cost_volume (line 42) | def compute_cost_volume(feat1, feat2, param_dict):
function upsample2d_as (line 65) | def upsample2d_as(inputs, target_as, mode="bilinear"):
function rescale_flow (line 70) | def rescale_flow(flow, div_flow, width_im, height_im, to_local=True):
class FeatureExtractor (line 85) | class FeatureExtractor(nn.Module):
method __init__ (line 86) | def __init__(self, num_chs):
method forward (line 98) | def forward(self, x):
function get_grid (line 107) | def get_grid(x):
class WarpingLayer (line 115) | class WarpingLayer(nn.Module):
method __init__ (line 116) | def __init__(self):
method forward (line 119) | def forward(self, x, flow, height_im, width_im, div_flow):
class OpticalFlowEstimator (line 135) | class OpticalFlowEstimator(nn.Module):
method __init__ (line 136) | def __init__(self, ch_in):
method forward (line 148) | def forward(self, x):
class FlowEstimatorDense (line 153) | class FlowEstimatorDense(nn.Module):
method __init__ (line 154) | def __init__(self, ch_in):
method forward (line 163) | def forward(self, x):
class OcclusionEstimator (line 173) | class OcclusionEstimator(nn.Module):
method __init__ (line 174) | def __init__(self, ch_in):
method forward (line 185) | def forward(self, x):
class OccEstimatorDense (line 190) | class OccEstimatorDense(nn.Module):
method __init__ (line 191) | def __init__(self, ch_in):
method forward (line 200) | def forward(self, x):
class ContextNetwork (line 210) | class ContextNetwork(nn.Module):
method __init__ (line 211) | def __init__(self, ch_in):
method forward (line 224) | def forward(self, x):
class OccContextNetwork (line 228) | class OccContextNetwork(nn.Module):
method __init__ (line 229) | def __init__(self, ch_in):
method forward (line 242) | def forward(self, x):
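`compute_cost_volume` in `pwc_modules.py` correlates a feature map with its warped counterpart over a local search window, PWC-Net style. A minimal sketch of that idea, assuming a channel-averaged dot product over a square displacement window; the `max_disp` parameter name is illustrative (the repository reads its settings from `param_dict`), and the CUDA correlation package above implements the same operation far more efficiently:

```python
import torch
import torch.nn.functional as F

def compute_cost_volume(feat1, feat2, max_disp=4):
    # Correlate feat1 [B, C, H, W] with feat2 shifted over a
    # (2*max_disp+1)^2 displacement window, averaging over channels.
    _, _, h, w = feat1.shape
    feat2_pad = F.pad(feat2, [max_disp] * 4)  # pad left/right/top/bottom
    costs = []
    for dy in range(2 * max_disp + 1):
        for dx in range(2 * max_disp + 1):
            shifted = feat2_pad[:, :, dy:dy + h, dx:dx + w]
            costs.append((feat1 * shifted).mean(dim=1, keepdim=True))
    return torch.cat(costs, dim=1)  # [B, (2*max_disp+1)^2, H, W]
```

The center channel (zero displacement) is just the per-pixel channel-mean product of the two feature maps, which makes for an easy sanity check.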
FILE: models/pwcnet.py
class PWCNet (line 9) | class PWCNet(nn.Module):
method __init__ (line 10) | def __init__(self, args, div_flow=0.05):
method forward (line 43) | def forward(self, input_dict):
FILE: models/pwcnet_bi.py
class PWCNet (line 9) | class PWCNet(nn.Module):
method __init__ (line 10) | def __init__(self, args, div_flow=0.05):
method forward (line 42) | def forward(self, input_dict):
FILE: models/pwcnet_irr.py
class PWCNet (line 9) | class PWCNet(nn.Module):
method __init__ (line 10) | def __init__(self, args, div_flow=0.05):
method forward (line 40) | def forward(self, input_dict):
FILE: models/pwcnet_irr_bi.py
class PWCNet (line 9) | class PWCNet(nn.Module):
method __init__ (line 10) | def __init__(self, args, div_flow=0.05):
method forward (line 40) | def forward(self, input_dict):
FILE: models/pwcnet_irr_occ.py
class PWCNet (line 9) | class PWCNet(nn.Module):
method __init__ (line 10) | def __init__(self, args, div_flow=0.05):
method forward (line 43) | def forward(self, input_dict):
FILE: models/pwcnet_irr_occ_bi.py
class PWCNet (line 9) | class PWCNet(nn.Module):
method __init__ (line 10) | def __init__(self, args, div_flow=0.05):
method forward (line 43) | def forward(self, input_dict):
FILE: models/pwcnet_occ.py
class PWCNet (line 9) | class PWCNet(nn.Module):
method __init__ (line 10) | def __init__(self, args, div_flow=0.05):
method forward (line 49) | def forward(self, input_dict):
FILE: models/pwcnet_occ_bi.py
class PWCNet (line 9) | class PWCNet(nn.Module):
method __init__ (line 10) | def __init__(self, args, div_flow=0.05):
method forward (line 49) | def forward(self, input_dict):
FILE: runtime.py
function create_progressbar (line 32) | def create_progressbar(iterable,
function tensor2float_dict (line 93) | def tensor2float_dict(tensor_dict):
function format_moving_averages_as_progress_dict (line 97) | def format_moving_averages_as_progress_dict(moving_averages_dict={},
function format_learning_rate (line 106) | def format_learning_rate(lr):
class TrainingEpoch (line 113) | class TrainingEpoch:
method __init__ (line 114) | def __init__(self,
method _step (line 131) | def _step(self, example_dict):
method run (line 196) | def run(self, offset=0):
class EvaluationEpoch (line 258) | class EvaluationEpoch:
method __init__ (line 259) | def __init__(self,
method save_outputs (line 276) | def save_outputs(self, example_dict, output_dict):
method _step (line 354) | def _step(self, example_dict):
method run (line 391) | def run(self, offset=0):
function exec_runtime (line 472) | def exec_runtime(args,
FILE: tools.py
function x2module (line 25) | def x2module(module_or_data_parallel):
function addLoggingLevel (line 37) | def addLoggingLevel(level_name, level_num, method_name=None):
function kwargs_from_args (line 67) | def kwargs_from_args(args, name, exclude=[]):
function instance_from_kwargs (line 84) | def instance_from_kwargs(class_constructor, kwargs):
function module_classes_to_dict (line 92) | def module_classes_to_dict(module, include_classes="*", exclude_classes=...
function ensure_dir (line 127) | def ensure_dir(file_path):
function search_and_replace (line 133) | def search_and_replace(string, regex, replace):
function hostname (line 143) | def hostname():
function get_filenames (line 151) | def get_filenames(directory, match='*.*', not_match=()):
function str2bool (line 170) | def str2bool(v):
function str2str_or_none (line 179) | def str2str_or_none(v):
function str2dict (line 185) | def str2dict(v):
function str2intlist (line 189) | def str2intlist(v):
function str2list (line 193) | def str2list(v):
function read_json (line 197) | def read_json(filename):
function write_json (line 224) | def write_json(data_dict, filename):
function datestr (line 229) | def datestr():
function filter_list_of_strings (line 235) | def filter_list_of_strings(lst, include="*", exclude=()):
function write_dictionary_to_file (line 246) | def write_dictionary_to_file(arguments_dict, filename):
class MovingAverage (line 284) | class MovingAverage:
method __init__ (line 287) | def __init__(self):
method add_value (line 291) | def add_value(self, sigma, addcount=1):
method add_average (line 295) | def add_average(self, avg, addcount):
method mean (line 299) | def mean(self):
class ExponentialMovingAverage (line 303) | class ExponentialMovingAverage:
method __init__ (line 306) | def __init__(self, alpha=0.7):
method add_value (line 311) | def add_value(self, sigma, addcount=1):
method add_average (line 315) | def add_average(self, avg, addcount):
method mean (line 319) | def mean(self):
class TqdmToLogger (line 328) | class TqdmToLogger(tqdm.tqdm):
method __init__ (line 329) | def __init__(self, iterable=None, desc=None, total=None, leave=True,
method format_meter (line 351) | def format_meter(n, total, elapsed, ncols=None, prefix='', ascii=False,
method update (line 367) | def update(self, n=1):
method close (line 373) | def close(self):
function tqdm_with_logging (line 381) | def tqdm_with_logging(iterable=None, desc=None, total=None, leave=True,
function cd_dotdot (line 401) | def cd_dotdot(path_or_filename):
function cd_dotdotdot (line 405) | def cd_dotdotdot(path_or_filename):
function cd_dotdotdotdot (line 409) | def cd_dotdotdotdot(path_or_filename):
function tensor2numpy (line 413) | def tensor2numpy(tensor):
FILE: utils/flow.py
function write_flow (line 11) | def write_flow(filename, uv, v=None):
function write_flow_png (line 37) | def write_flow_png(filename, uv, v=None, mask=None):
function flow_to_png (line 65) | def flow_to_png(flow_map, max_value=None):
function compute_color (line 79) | def compute_color(u, v):
function make_color_wheel (line 123) | def make_color_wheel():
function flow_to_png_middlebury (line 173) | def flow_to_png_middlebury(flow):
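`write_flow` in `utils/flow.py` writes optical-flow files; helpers like it conventionally target the Middlebury `.flo` layout: a float32 magic number 202021.25, int32 width and height, then `2*w*h` interleaved (u, v) float32 values in row-major order. A minimal read/write sketch of that format (the function names are illustrative, and the repository's `write_flow` signature differs slightly):

```python
import numpy as np

def write_flo(filename, flow):
    # flow: [H, W, 2] array of (u, v) displacements.
    h, w, _ = flow.shape
    with open(filename, "wb") as f:
        np.array(202021.25, dtype=np.float32).tofile(f)  # magic ("PIEH")
        np.array(w, dtype=np.int32).tofile(f)
        np.array(h, dtype=np.int32).tofile(f)
        flow.astype(np.float32).tofile(f)                # interleaved u, v

def read_flo(filename):
    with open(filename, "rb") as f:
        magic = np.fromfile(f, np.float32, 1)[0]
        assert np.isclose(magic, 202021.25), "invalid .flo file"
        w = int(np.fromfile(f, np.int32, 1)[0])
        h = int(np.fromfile(f, np.int32, 1)[0])
        return np.fromfile(f, np.float32, 2 * w * h).reshape(h, w, 2)
```

A write/read round trip should reproduce the input exactly, since the format stores raw float32 values.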
FILE: utils/interpolation.py
function _bchw2bhwc (line 10) | def _bchw2bhwc(tensor):
function _bhwc2bchw (line 14) | def _bhwc2bchw(tensor):
class Meshgrid (line 18) | class Meshgrid(nn.Module):
method __init__ (line 19) | def __init__(self):
method _compute_meshgrid (line 26) | def _compute_meshgrid(self, width, height):
method forward (line 32) | def forward(self, width, height, device=None, dtype=None):
class BatchSub2Ind (line 42) | class BatchSub2Ind(nn.Module):
method __init__ (line 43) | def __init__(self):
method forward (line 47) | def forward(self, shape, row_sub, col_sub, out=None):
class Interp2 (line 60) | class Interp2(nn.Module):
method __init__ (line 61) | def __init__(self, clamp=False):
method forward (line 80) | def forward(self, v, xq, yq):
class Interp2MaskBinary (line 144) | class Interp2MaskBinary(nn.Module):
method __init__ (line 145) | def __init__(self, clamp=False):
method forward (line 168) | def forward(self, v, xq, yq, mask):
function resize2D (line 247) | def resize2D(inputs, size_targets, mode="bilinear"):
function resize2D_as (line 261) | def resize2D_as(inputs, output_as, mode="bilinear"):
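Two of the `utils/interpolation.py` helpers indexed above have conventional semantics worth spelling out: `BatchSub2Ind` flattens (row, col) subscripts into linear indices (mirroring MATLAB's `sub2ind`), and `resize2D` presumably wraps `F.interpolate`. A minimal sketch of both under those assumptions; the free-function form and the `align` handling are illustrative, not the repository's exact code:

```python
import torch
import torch.nn.functional as F

def batch_sub2ind(shape, row_sub, col_sub):
    # Linear indices into a row-major [H, W] grid: ind = row * W + col.
    _, w = shape
    return row_sub * w + col_sub

def resize2d(inputs, size_targets, mode="bilinear"):
    # Resize a [B, C, H, W] tensor to (h, w) = size_targets.
    align = True if mode in ("bilinear", "bicubic") else None
    return F.interpolate(inputs, size=size_targets, mode=mode,
                         align_corners=align)
```

Linear indices like these are what `Interp2` needs to gather the four neighbors of each query point from a flattened feature map.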
Condensed preview — 89 files, each entry listing path, character count, and a content snippet (full structured content: 495K chars).
[
{
"path": ".gitignore",
"chars": 14,
"preview": "*.pyc\n*.so\n*.o"
},
{
"path": "LICENSE",
"chars": 11357,
"preview": " Apache License\n Version 2.0, January 2004\n "
},
{
"path": "README.md",
"chars": 3135,
"preview": "# Iterative Residual Refinement <br/> for Joint Optical Flow and Occlusion Estimation\n\n<img src=output.gif>\n\nThis reposi"
},
{
"path": "__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "augmentations.py",
"chars": 43244,
"preview": "## Portions of Code from, copyright 2018 Jochen Gast\n\nfrom __future__ import absolute_import, division, print_function\n\n"
},
{
"path": "commandline.py",
"chars": 18356,
"preview": "## Portions of Code from, copyright 2018 Jochen Gast\n\nfrom __future__ import absolute_import, division, print_function\n\n"
},
{
"path": "configuration.py",
"chars": 27481,
"preview": "## Portions of Code from, copyright 2018 Jochen Gast\n\nfrom __future__ import absolute_import, division, print_function\n\n"
},
{
"path": "datasets/__init__.py",
"chars": 2040,
"preview": "from . import flyingchairs\nfrom . import flyingchairsOcc\nfrom . import sintel\nfrom . import flyingThings3D\nfrom . import"
},
{
"path": "datasets/common.py",
"chars": 1204,
"preview": "## Portions of Code from, copyright 2018 Jochen Gast\n\nfrom __future__ import absolute_import, division, print_function\n\n"
},
{
"path": "datasets/flyingThings3D.py",
"chars": 9224,
"preview": "from __future__ import absolute_import, division, print_function\n\nimport os\nimport torch.utils.data as data\nfrom glob im"
},
{
"path": "datasets/flyingchairs.py",
"chars": 9689,
"preview": "from __future__ import absolute_import, division, print_function\n\nimport os\nimport torch.utils.data as data\nfrom glob im"
},
{
"path": "datasets/flyingchairsOcc.py",
"chars": 10965,
"preview": "from __future__ import absolute_import, division, print_function\n\nimport os\nimport torch.utils.data as data\nfrom glob im"
},
{
"path": "datasets/kitti_combined.py",
"chars": 22699,
"preview": "from __future__ import absolute_import, division, print_function\n\nimport os\nimport torch.utils.data as data\nfrom glob im"
},
{
"path": "datasets/sintel.py",
"chars": 17187,
"preview": "from __future__ import absolute_import, division, print_function\n\nimport os\nimport torch.utils.data as data\nfrom glob im"
},
{
"path": "datasets/transforms.py",
"chars": 2420,
"preview": "## Portions of Code from, copyright 2018 Jochen Gast\n\nfrom __future__ import absolute_import, division, print_function\n\n"
},
{
"path": "flyingchairsocc/README.md",
"chars": 2553,
"preview": "# FlyingChairsOcc dataset\n\n<img src=demo_img.png>\n\nThe FlyingChairsOcc dataset is an extended version of the <a href=\"ht"
},
{
"path": "install.sh",
"chars": 74,
"preview": "#!/bin/bash\ncd ./models/correlation_package\npython setup.py install\ncd ..\n"
},
{
"path": "logger.py",
"chars": 3990,
"preview": "## Portions of Code from, copyright 2018 Jochen Gast\n\nfrom __future__ import absolute_import, division, print_function\n\n"
},
{
"path": "losses.py",
"chars": 29114,
"preview": "from __future__ import absolute_import, division, print_function\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.fun"
},
{
"path": "main.py",
"chars": 2955,
"preview": "from __future__ import absolute_import, division, print_function\n\nimport os\nimport subprocess\nimport commandline\nimport "
},
{
"path": "models/IRR_FlowNet.py",
"chars": 14424,
"preview": "from __future__ import absolute_import, division, print_function\n\nimport torch\nimport torch.nn as nn\nfrom .flownet_modul"
},
{
"path": "models/IRR_PWC.py",
"chars": 9002,
"preview": "from __future__ import absolute_import, division, print_function\n\nimport torch\nimport torch.nn as nn\n\nfrom .pwc_modules "
},
{
"path": "models/__init__.py",
"chars": 1084,
"preview": "from . import flownet1s\nfrom . import flownet1s_irr\nfrom . import flownet1s_irr_bi\nfrom . import flownet1s_irr_occ\nfrom "
},
{
"path": "models/correlation_package/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "models/correlation_package/correlation.py",
"chars": 2265,
"preview": "import torch\nfrom torch.nn.modules.module import Module\nfrom torch.autograd import Function\nimport correlation_cuda\n\ncla"
},
{
"path": "models/correlation_package/correlation_cuda.cc",
"chars": 6467,
"preview": "#include <torch/torch.h>\n#include <ATen/ATen.h>\n#include <stdio.h>\n#include <iostream>\n\n#include \"correlation_cuda_kerne"
},
{
"path": "models/correlation_package/correlation_cuda_kernel.cu",
"chars": 13882,
"preview": "#include <stdio.h>\n\n#include \"correlation_cuda_kernel.cuh\"\n\n#define CUDA_NUM_THREADS 1024\n#define THREADS_PER_BLOCK 32\n\n"
},
{
"path": "models/correlation_package/correlation_cuda_kernel.cuh",
"chars": 1409,
"preview": "#pragma once\n\n#include <ATen/ATen.h>\n#include <ATen/Context.h>\n#include <cuda_runtime.h>\n\nint correlation_forward_cuda_k"
},
{
"path": "models/correlation_package/setup.py",
"chars": 745,
"preview": "#!/usr/bin/env python3\nimport os\nimport torch\n\nfrom setuptools import setup, find_packages\nfrom torch.utils.cpp_extensio"
},
{
"path": "models/correlation_package_cu9/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "models/correlation_package_cu9/correlation.py",
"chars": 2265,
"preview": "import torch\nfrom torch.nn.modules.module import Module\nfrom torch.autograd import Function\nimport correlation_cuda\n\ncla"
},
{
"path": "models/correlation_package_cu9/correlation_cuda.cc",
"chars": 6626,
"preview": "#include <torch/extension.h>\n#include <ATen/ATen.h>\n#include <ATen/Context.h>\n#include <ATen/cuda/CUDAContext.h>\n#includ"
},
{
"path": "models/correlation_package_cu9/correlation_cuda_kernel.cu",
"chars": 13882,
"preview": "#include <stdio.h>\n\n#include \"correlation_cuda_kernel.cuh\"\n\n#define CUDA_NUM_THREADS 1024\n#define THREADS_PER_BLOCK 32\n\n"
},
{
"path": "models/correlation_package_cu9/correlation_cuda_kernel.cuh",
"chars": 1409,
"preview": "#pragma once\n\n#include <ATen/ATen.h>\n#include <ATen/Context.h>\n#include <cuda_runtime.h>\n\nint correlation_forward_cuda_k"
},
{
"path": "models/correlation_package_cu9/setup.py",
"chars": 817,
"preview": "#!/usr/bin/env python3\nimport os\nimport torch\n\nfrom setuptools import setup, find_packages\nfrom torch.utils.cpp_extensio"
},
{
"path": "models/flownet1s.py",
"chars": 5176,
"preview": "from __future__ import absolute_import, division, print_function\n\nimport torch\nimport torch.nn as nn\nfrom .flownet_modul"
},
{
"path": "models/flownet1s_irr.py",
"chars": 7097,
"preview": "from __future__ import absolute_import, division, print_function\r\n\r\nimport torch\r\nimport torch.nn as nn\r\nfrom .flownet_m"
},
{
"path": "models/flownet1s_irr_bi.py",
"chars": 7740,
"preview": "from __future__ import absolute_import, division, print_function\r\n\r\nimport torch\r\nimport torch.nn as nn\r\nfrom .flownet_m"
},
{
"path": "models/flownet1s_irr_occ.py",
"chars": 10213,
"preview": "from __future__ import absolute_import, division, print_function\r\n\r\nimport torch\r\nimport torch.nn as nn\r\nfrom .flownet_m"
},
{
"path": "models/flownet1s_irr_occ_bi.py",
"chars": 11489,
"preview": "from __future__ import absolute_import, division, print_function\r\n\r\nimport torch\r\nimport torch.nn as nn\r\nfrom .flownet_m"
},
{
"path": "models/flownet_modules.py",
"chars": 3808,
"preview": "from __future__ import absolute_import, division, print_function\r\n\r\nimport torch\r\nimport torch.nn as nn\r\nimport torch.nn"
},
{
"path": "models/irr_modules.py",
"chars": 5042,
"preview": "from __future__ import absolute_import, division, print_function\r\n\r\nimport torch\r\nimport torch.nn as nn\r\nimport torch.nn"
},
{
"path": "models/pwc_modules.py",
"chars": 7999,
"preview": "from __future__ import absolute_import, division, print_function\r\n\r\nimport torch\r\nimport torch.nn as nn\r\nimport torch.nn"
},
{
"path": "models/pwcnet.py",
"chars": 3633,
"preview": "from __future__ import absolute_import, division, print_function\r\n\r\nimport torch\r\nimport torch.nn as nn\r\n\r\nfrom .pwc_mod"
},
{
"path": "models/pwcnet_bi.py",
"chars": 4465,
"preview": "from __future__ import absolute_import, division, print_function\r\n\r\nimport torch\r\nimport torch.nn as nn\r\n\r\nfrom .pwc_mod"
},
{
"path": "models/pwcnet_irr.py",
"chars": 3888,
"preview": "from __future__ import absolute_import, division, print_function\r\n\r\nimport torch\r\nimport torch.nn as nn\r\n\r\nfrom .pwc_mod"
},
{
"path": "models/pwcnet_irr_bi.py",
"chars": 4898,
"preview": "from __future__ import absolute_import, division, print_function\r\n\r\nimport torch\r\nimport torch.nn as nn\r\n\r\nfrom .pwc_mod"
},
{
"path": "models/pwcnet_irr_occ.py",
"chars": 4717,
"preview": "from __future__ import absolute_import, division, print_function\r\n\r\nimport torch\r\nimport torch.nn as nn\r\n\r\nfrom .pwc_mod"
},
{
"path": "models/pwcnet_irr_occ_bi.py",
"chars": 6267,
"preview": "from __future__ import absolute_import, division, print_function\r\n\r\nimport torch\r\nimport torch.nn as nn\r\n\r\nfrom .pwc_mod"
},
{
"path": "models/pwcnet_occ.py",
"chars": 4645,
"preview": "from __future__ import absolute_import, division, print_function\r\n\r\nimport torch\r\nimport torch.nn as nn\r\n\r\nfrom .pwc_mod"
},
{
"path": "models/pwcnet_occ_bi.py",
"chars": 6027,
"preview": "from __future__ import absolute_import, division, print_function\r\n\r\nimport torch\r\nimport torch.nn as nn\r\n\r\nfrom .pwc_mod"
},
{
"path": "optim/__init__.py",
"chars": 511,
"preview": "import torch\nimport sys\nfrom tools import module_classes_to_dict\n\n# ----------------------------------------------------"
},
{
"path": "runtime.py",
"chars": 26688,
"preview": "## Portions of Code from, copyright 2018 Jochen Gast\n\nfrom __future__ import absolute_import, division, print_function\n\n"
},
{
"path": "scripts/IRR-FlowNet_flyingChairsOcc.sh",
"chars": 1200,
"preview": "#!/bin/bash\n\n# experiments and datasets meta\nEXPERIMENTS_HOME=\"experiments\"\n\n# datasets\nFLYINGCHAIRS_OCC_HOME=(YOUR PATH"
},
{
"path": "scripts/IRR-PWC_flyingChairsOcc.sh",
"chars": 1172,
"preview": "#!/bin/bash\n\n# experiments and datasets meta\nEXPERIMENTS_HOME=\"experiments\"\n\n# datasets\nFLYINGCHAIRS_OCC_HOME=(YOUR PATH"
},
{
"path": "scripts/IRR-PWC_kitti_train.sh",
"chars": 1378,
"preview": "#!/bin/bash\n\n# experiments and datasets meta\nEXPERIMENTS_HOME=\"experiments\"\n\n# datasets\nKITTI_HOME=(YOUR PATH)/KITTI_flo"
},
{
"path": "scripts/IRR-PWC_kitti_train_full.sh",
"chars": 1375,
"preview": "#!/bin/bash\n\n# experiments and datasets meta\nEXPERIMENTS_HOME=\"experiments\"\n\n# datasets\nKITTI_HOME=(YOUR PATH)/KITTI_flo"
},
{
"path": "scripts/IRR-PWC_sintel_train.sh",
"chars": 2258,
"preview": "#!/bin/bash\n\n# experiments and datasets meta\nEXPERIMENTS_HOME=\"experiments\"\n\n# datasets\nSINTEL_HOME=(YOUR PATH)/MPI-Sint"
},
{
"path": "scripts/IRR-PWC_sintel_train_full.sh",
"chars": 2252,
"preview": "#!/bin/bash\n\n# experiments and datasets meta\nEXPERIMENTS_HOME=\"experiments\"\n\n# datasets\nSINTEL_HOME=(YOUR PATH)/MPI-Sint"
},
{
"path": "scripts/IRR-PWC_things3d.sh",
"chars": 1349,
"preview": "#!/bin/bash\n\n# experiments and datasets meta\nEXPERIMENTS_HOME=\"experiments\"\n\n# datasets\nFLYINGTHINGS_HOME=(YOUR PATH)/th"
},
{
"path": "scripts/flownet1s.sh",
"chars": 1181,
"preview": "#!/bin/bash\n\n# experiments and datasets meta\nEXPERIMENTS_HOME=\"experiments\"\n\n# datasets\nFLYINGCHAIRS_OCC_HOME=(YOUR PATH"
},
{
"path": "scripts/flownet1s_irr1.sh",
"chars": 1189,
"preview": "#!/bin/bash\n\n# experiments and datasets meta\nEXPERIMENTS_HOME=\"experiments\"\n\n# datasets\nFLYINGCHAIRS_OCC_HOME=(YOUR PATH"
},
{
"path": "scripts/flownet1s_irr2.sh",
"chars": 1186,
"preview": "#!/bin/bash\n\n# experiments and datasets meta\nEXPERIMENTS_HOME=\"experiments\"\n\n# datasets\nFLYINGCHAIRS_OCC_HOME=(YOUR PATH"
},
{
"path": "scripts/pwcnet.sh",
"chars": 1158,
"preview": "#!/bin/bash\n\n# experiments and datasets meta\nEXPERIMENTS_HOME=\"experiments\"\n\n# datasets\nFLYINGCHAIRS_OCC_HOME=(YOUR PATH"
},
{
"path": "scripts/pwcnet_irr.sh",
"chars": 1162,
"preview": "#!/bin/bash\n\n# experiments and datasets meta\nEXPERIMENTS_HOME=\"experiments\"\n\n# datasets\nFLYINGCHAIRS_OCC_HOME=(YOUR PATH"
},
{
"path": "scripts/validation/IRR-FlowNet_flyingChairs.sh",
"chars": 814,
"preview": "#!/bin/bash\n\n# experiments and datasets meta\nEXPERIMENTS_HOME=\"saved_check_point/flownet\"\n\n# datasets\nSINTEL_HOME=(YOUR "
},
{
"path": "scripts/validation/IRR-PWC_flyingChairs.sh",
"chars": 784,
"preview": "#!/bin/bash\n\n# experiments and datasets meta\nEXPERIMENTS_HOME=\"saved_check_point/pwcnet\"\n\n# datasets\nSINTEL_HOME=(YOUR P"
},
{
"path": "scripts/validation/IRR-PWC_kitti.sh",
"chars": 760,
"preview": "#!/bin/bash\n\n# experiments and datasets meta\nEXPERIMENTS_HOME=\"saved_check_point/pwcnet\"\n\n# datasets\nKITTI_HOME=(YOUR PA"
},
{
"path": "scripts/validation/IRR-PWC_sintel.sh",
"chars": 785,
"preview": "#!/bin/bash\n\n# experiments and datasets meta\nEXPERIMENTS_HOME=\"saved_check_point/pwcnet\"\n\n# datasets\nSINTEL_HOME=(YOUR P"
},
{
"path": "scripts/validation/IRR-PWC_things3d.sh",
"chars": 779,
"preview": "#!/bin/bash\n\n# experiments and datasets meta\nEXPERIMENTS_HOME=\"saved_check_point/pwcnet\"\n\n# datasets\nSINTEL_HOME=(YOUR P"
},
{
"path": "scripts/validation/flownet1s.sh",
"chars": 777,
"preview": "#!/bin/bash\n\n# experiments and datasets meta\nEXPERIMENTS_HOME=\"saved_check_point/flownet\"\n\n# datasets\nSINTEL_HOME=(YOUR "
},
{
"path": "scripts/validation/flownet1s_irr1.sh",
"chars": 790,
"preview": "#!/bin/bash\n\n# experiments and datasets meta\nEXPERIMENTS_HOME=\"saved_check_point/flownet\"\n\n# datasets\nSINTEL_HOME=(YOUR "
},
{
"path": "scripts/validation/flownet1s_irr2.sh",
"chars": 790,
"preview": "#!/bin/bash\n\n# experiments and datasets meta\nEXPERIMENTS_HOME=\"saved_check_point/flownet\"\n\n# datasets\nSINTEL_HOME=(YOUR "
},
{
"path": "scripts/validation/pwcnet.sh",
"chars": 750,
"preview": "#!/bin/bash\n\n# experiments and datasets meta\nEXPERIMENTS_HOME=\"saved_check_point/pwcnet\"\n\n# datasets\nSINTEL_HOME=(YOUR P"
},
{
"path": "scripts/validation/pwcnet_irr.sh",
"chars": 758,
"preview": "#!/bin/bash\n\n# experiments and datasets meta\nEXPERIMENTS_HOME=\"saved_check_point/pwcnet\"\n\n# datasets\nSINTEL_HOME=(YOUR P"
},
{
"path": "tools.py",
"chars": 15218,
"preview": "## Portions of Code from, copyright 2018 Jochen Gast\n\nfrom __future__ import absolute_import, division, print_function\n\n"
},
{
"path": "utils/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "utils/flow.py",
"chars": 5201,
"preview": "from __future__ import absolute_import, division, print_function\n\nimport numpy as np\nimport png\nimport matplotlib.colors"
},
{
"path": "utils/interpolation.py",
"chars": 10980,
"preview": "## Portions of Code from, copyright 2018 Jochen Gast\n\nfrom __future__ import absolute_import, division, print_function\n\n"
}
]
// ... and 10 more files (omitted from this preview)
About this extraction
This page contains the source code of the visinf/irr GitHub repository, extracted and formatted as plain text. The extraction covers 89 files (240.7 MB), approximately 124.1k tokens, and a symbol index with 472 extracted functions, classes, methods, constants, and types.