Repository: visinf/dense-ulearn-vos
Branch: main
Commit: 88e5e3518cb0
Files: 40
Total size: 160.6 KB
Directory structure:
gitextract_9fs7eg_t/
├── .gitignore
├── LICENSE
├── README.md
├── base_trainer.py
├── configs/
│ ├── kinetics.yaml
│ ├── oxuva.yaml
│ ├── tracknet.yaml
│ └── ytvos.yaml
├── core/
│ ├── __init__.py
│ └── config.py
├── datasets/
│ ├── __init__.py
│ ├── dataloader_base.py
│ ├── dataloader_infer.py
│ ├── dataloader_seg.py
│ ├── dataloader_video.py
│ └── daugm_video.py
├── infer_vos.py
├── labelprop/
│ ├── common.py
│ └── crw.py
├── launch/
│ ├── infer_vos.sh
│ ├── train.sh
│ └── utils.bash
├── models/
│ ├── __init__.py
│ ├── base.py
│ ├── framework.py
│ ├── net.py
│ └── resnet18.py
├── opts.py
├── requirements.txt
├── train.py
└── utils/
├── checkpoints.py
├── collections.py
├── davis2017.py
├── davis2017_metrics.py
├── davis2017_utils.py
├── palette.py
├── palette_davis.py
├── stat_manager.py
├── sys_tools.py
└── timer.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
__pycache__/
*.pyc
*.sw*
data/
libs/
models/pretrained
logs
snapshots
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2021 TU Darmstadt
Author: Nikita Araslanov
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
# Dense Unsupervised Learning for Video Segmentation
[License: Apache 2.0](https://opensource.org/licenses/Apache-2.0)
[PyTorch](https://pytorch.org/)
This repository contains the official implementation of our paper:
**Dense Unsupervised Learning for Video Segmentation**<br>
[Nikita Araslanov](https://arnike.github.io), [Simone Schaub-Meyer](https://schaubsi.github.io) and [Stefan Roth](https://www.visinf.tu-darmstadt.de/visinf/team_members/sroth/sroth.en.jsp)<br>
To appear at NeurIPS*2021. [[paper](https://openreview.net/pdf?id=i8kfkuiCJCI)] [[supp](https://openreview.net/attachment?id=i8kfkuiCJCI&name=supplementary_material)] [[talk](https://youtu.be/tSBWZ6nYld0)] [[example results](https://youtu.be/BqVtZJSLOzg)] [[arXiv](https://arxiv.org/abs/2111.06265)]
| <img src="assets/examples.gif" alt="drawing" width="420"/><br> |
|:--:|
| <p align="left">We efficiently learn spatio-temporal correspondences <br> without any supervision, and achieve state-of-the-art <br>accuracy of video object segmentation.</p> |
Contact: Nikita Araslanov *fname.lname* (at) visinf.tu-darmstadt.de
---
## Installation
**Requirements.** To reproduce our results, we recommend Python >=3.6, PyTorch >=1.4, CUDA >=10.0. At least one Titan X GPU (12GB) or equivalent is required.
The code was primarily developed under PyTorch 1.8 on a single A100 GPU.
The following steps will set up a local copy of the repository.
1. Create conda environment:
```
conda create --name dense-ulearn-vos
source activate dense-ulearn-vos
```
2. Install PyTorch >=1.4 (see [PyTorch instructions](https://pytorch.org/get-started/locally/)). For example on Linux, run:
```
conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
```
3. Install the dependencies:
```
pip install -r requirements.txt
```
4. Download the data:
| Dataset | Website | Target directory with video sequences |
|:-:|:-:|:--|
| YouTube-VOS | [Link](https://competitions.codalab.org/competitions/19544#participate-get-data) | `data/ytvos/train/JPEGImages/` |
| OxUvA | [Link](https://oxuva.github.io/long-term-tracking-benchmark/) | `data/OxUvA/images/dev/` |
| TrackingNet | [Link](https://github.com/SilvioGiancola/TrackingNet-devkit) | `data/tracking/train/jpegs/` |
| Kinetics-400 | [Link](https://deepmind.com/research/open-source/kinetics) | `data/kinetics400/video_jpeg/train/` |
The last column in this table specifies a path to subdirectories (relative to the project root) containing images of video frames.
You can obviously use a different path structure.
In this case, you will need to adjust the paths in `data/filelists/` for every dataset accordingly.
5. Download filelists:
```
cd data/filelists
bash download.sh
```
This will download lists of training and validation paths for all datasets.
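Each filelist is a plain text file. For the training sets, every line holds a frame path relative to `DATASET.ROOT` (default: `data/`), and frames of the same video share a parent directory, which is how `datasets/dataloader_video.py` groups frames into sequences. A hypothetical excerpt:
```
ytvos/train/JPEGImages/<sequence-id>/00000.jpg
ytvos/train/JPEGImages/<sequence-id>/00005.jpg
```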
## Training
The following bash script will train a ResNet-18 model from scratch on one of the four supported datasets (see above):
```
bash ./launch/train.sh [ytvos|oxuva|track|kinetics]
```
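Each supported dataset has a YAML configuration in `configs/` (e.g. `configs/ytvos.yaml`) that overrides the defaults defined in `core/config.py`. A minimal sketch of that merge:
```
from core.config import cfg, merge_cfg_from_file

merge_cfg_from_file("configs/ytvos.yaml")           # merge the YTVOS setup into the defaults
print(cfg.TRAIN.TASK, cfg.MODEL.OPT, cfg.MODEL.LR)  # -> YTVOS Adam 0.0001
```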
We also provide our final models for download.
| Dataset | Mean J&F (DAVIS-2017) | Link | MD5 |
|---|:-:|:--:|---|
| OxUvA | 65.3 | [oxuva_e430_res4.pth (132M)](https://download.visinf.tu-darmstadt.de/data/2021-neurips-araslanov-vos/snapshots/oxuva_e430_res4.pth) | `af541[...]d09b3` |
| YouTube-VOS | 69.3 | [ytvos_e060_res4.pth (132M)](https://download.visinf.tu-darmstadt.de/data/2021-neurips-araslanov-vos/snapshots/ytvos_e060_res4.pth) | `c3ae3[...]55faf` |
| TrackingNet | 69.4 | [trackingnet_e088_res4.pth (88M)](https://download.visinf.tu-darmstadt.de/data/2021-neurips-araslanov-vos/snapshots/trackingnet_e088_res4.pth) | `3e7e9[...]95fa9` |
| Kinetics-400 | 68.7 | [kinetics_e026_res4.pth (88M)](https://download.visinf.tu-darmstadt.de/data/2021-neurips-araslanov-vos/snapshots/kinetics_e026_res4.pth) | `086db[...]a7d98` |
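To verify the integrity of a downloaded snapshot against the MD5 checksums in the table, a small helper such as the following can be used (not part of this repository):
```
import hashlib

def md5sum(path, chunk=1 << 20):
    h = hashlib.md5()
    with open(path, "rb") as f:
        for block in iter(lambda: f.read(chunk), b""):
            h.update(block)
    return h.hexdigest()

print(md5sum("ytvos_e060_res4.pth"))
```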
## Inference and evaluation
### Inference
To run the inference use `launch/infer_vos.sh`:
```
bash ./launch/infer_vos.sh [davis|ytvos]
```
The first argument selects the validation dataset to use (`davis` for DAVIS-2017; `ytvos` for YouTube-VOS).
The bash variables declared in the script set up the paths for reading the data and the pre-trained models, as well as the output directory:
* `EXP`, `RUN_ID` and `SNAPSHOT` determine the pre-trained model to load.
* `VER` specifies a suffix for the output directory (in case you would like to experiment with different configurations for label propagation).
Please refer to `launch/infer_vos.sh` for their usage.
The inference script will create two directories with the results: `[res3|res4|key]_vos` and `[res3|res4|key]_vis`, where the prefix corresponds to the codename of the output CNN layer used in the evaluation (selected in `infer_vos.sh` via the `KEY` variable).
The `vos`-directory contains the segmentation results ready for evaluation; the `vis`-directory contains the results for visualisation purposes.
You can optionally disable generating the visualisation by setting `VERBOSE=False` in `infer_vos.py`.
### Evaluation: DAVIS-2017
Please use the official [evaluation package](https://github.com/davisvideochallenge/davis2017-evaluation).
Install the repository, then simply run:
```
python evaluation_method.py --task semi-supervised --davis_path data/davis2017 --results_path <path-to-vos-directory>
```
### Evaluation: YouTube-VOS 2018
Please use the official [CodaLab evaluation server](https://competitions.codalab.org/competitions/19544#participate-submit_results).
To create the submission, rename the `vos`-directory to `Annotations` and compress it to `Annotations.zip` for uploading.
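For example, a minimal Python sketch (the source path is a placeholder for your `vos`-directory):
```
import shutil

shutil.copytree("<path-to-vos-directory>", "Annotations")      # rename the vos-directory
shutil.make_archive("Annotations", "zip", ".", "Annotations")  # creates Annotations.zip
```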
## Acknowledgements
We thank PyTorch contributors and [Allan Jabri](https://ajabri.github.io) for releasing [their implementation](https://github.com/ajabri/videowalk) of the label propagation.
## Citation
We hope you find our work useful. If you would like to acknowledge it in your project, please use the following citation:
```
@inproceedings{Araslanov:2021:DUL,
author = {Araslanov, Nikita and Schaub-Meyer, Simone and Roth, Stefan},
title = {Dense Unsupervised Learning for Video Segmentation},
booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
volume = {34},
year = {2021}
}
```
================================================
FILE: base_trainer.py
================================================
"""
Copyright (c) 2021 TU Darmstadt
Author: Nikita Araslanov <nikita.araslanov@tu-darmstadt.de>
License: Apache License 2.0
"""
import os
import torch
import math
import numpy as np
import torch.nn.functional as F
import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter
from torch.optim.optimizer import Optimizer
from utils.checkpoints import Checkpoint
from utils.palette_davis import palette as palette_davis
from PIL import Image
from matplotlib import cm
class BaseTrainer(object):
def __init__(self, args, cfg):
self.args = args
self.cfg = cfg
self.start_epoch = 0
self.best_score = -1e16
self.checkpoint = Checkpoint(args.snapshot_dir, max_n = 3)
logdir = os.path.join(args.logdir, 'train')
self.writer = SummaryWriter(logdir)
def checkpoint_best(self, score, epoch, temp):
if score > self.best_score:
print(">>> Saving checkpoint with score {:3.2e}, epoch {}".format(score, epoch))
self.best_score = score
self.checkpoint.checkpoint(score, epoch, temp)
return True
return False
@staticmethod
def get_optim(params, cfg):
if not hasattr(torch.optim, cfg.OPT):
print("Optimiser {} not supported".format(cfg.OPT))
raise NotImplementedError
optim = getattr(torch.optim, cfg.OPT)
if cfg.OPT == 'Adam':
print("Using Adam >>> learning rate = {:4.3e}, momentum = {:4.3e}, weight decay = {:4.3e}".format(cfg.LR, cfg.MOMENTUM, cfg.WEIGHT_DECAY))
upd = torch.optim.Adam(params, lr=cfg.LR, \
betas=(cfg.BETA1, 0.999), \
weight_decay=cfg.WEIGHT_DECAY)
elif cfg.OPT == 'SGD':
print("Using SGD >>> learning rate = {:4.3e}, momentum = {:4.3e}, weight decay = {:4.3e}".format(cfg.LR, cfg.MOMENTUM, cfg.WEIGHT_DECAY))
upd = torch.optim.SGD(params, lr=cfg.LR, \
momentum=cfg.MOMENTUM, \
nesterov=cfg.OPT_NESTEROV, \
weight_decay=cfg.WEIGHT_DECAY)
else:
upd = optim(params, lr=cfg.LR)
upd.zero_grad()
return upd
@staticmethod
def set_lr(optim, lr):
for param_group in optim.param_groups:
param_group['lr'] = lr
def _downsize(self, x, mode="bilinear"):
x = x.float()
if x.dim() == 3:
x = x.unsqueeze(1)
scale = min(*self.cfg.TB.IM_SIZE) / min(x.shape[-1], x.shape[-2])
if mode == "nearest":
x = F.interpolate(x, scale_factor=scale, mode="nearest")
else:
x = F.interpolate(x, scale_factor=scale, mode=mode, align_corners=True)
return x.squeeze(1)
def _visualise_seg(self, epoch, outs, writer, tag, S = 5):
def with_frame(image, mask, alpha=0.3):
return alpha * image + (1 - alpha) * mask
frames = outs["frames"][::S]
frames_norm = self.denorm(frames.cpu().clone())
frames_down = self._downsize(frames_norm)
T,C,h,w = frames_down.shape
visuals = []
visuals.append(frames_down)
if "masks_gt" in outs:
mask_rgb_gt = self._apply_cmap(outs["masks_gt"][::S].cpu(), palette_davis, rand=False)
mask_rgb_gt = self._downsize(mask_rgb_gt)
mask_rgb_gt = with_frame(frames_down, mask_rgb_gt)
visuals.append(mask_rgb_gt)
mask_rgb_idx = self._apply_cmap(outs["masks_pred_idx"][::S].cpu(), palette_davis, rand=False)
mask_rgb_idx = self._downsize(mask_rgb_idx)
mask_rgb_idx = with_frame(frames_down, mask_rgb_idx)
visuals.append(mask_rgb_idx)
#if "masks_pred_conf" in outs:
conf = self._downsize(outs["masks_pred_conf"][::S].cpu())
conf_rgb = self._error_rgb(conf, cm.get_cmap("plasma"), frames_down, 0.3)
visuals.append(conf_rgb)
visuals = [x.float() for x in visuals]
visuals = torch.cat(visuals, -1)
self._visualise_grid(writer, visuals, epoch, tag)
def _visualise(self, epoch, outs, T, writer, tag):
visuals = []
def overlay(mask, image, alpha=0.3):
return alpha * image + (1 - alpha) * mask
frames_orig = outs["frames_orig"]
frames_orig = self.denorm(frames_orig.cpu().clone())
frames_orig = self._downsize(frames_orig)
visuals.append(frames_orig)
frames = outs["frames"]
frames_norm = self.denorm(frames.cpu().clone())
frames_down = self._downsize(frames_norm)
if "grid_mask" in outs:
val_mask = outs["grid_mask"]
val_mask = self._downsize(val_mask)
val_mask = val_mask.unsqueeze(1).expand(-1,3,-1,-1).cpu()
val_mask = overlay(val_mask, frames_orig)
visuals.append(val_mask)
if "map_target" in outs:
val = outs["map_target"]
val = self._apply_cmap(val)
val = self._downsize(val, "nearest")
val = overlay(val, frames_down)
visuals.append(val)
if "map_soft" in outs:
val = outs["map_soft"]
val = self._mask_rgb(val)
val = self._downsize(val)
visuals.append(val)
visuals.append(frames_down)
frames2 = outs["frames2"]
frames2_norm = self.denorm(frames2.cpu().clone())
frames2_down = self._downsize(frames2_norm)
visuals.append(frames2_down)
if "map" in outs:
val = outs["map"]
val = self._apply_cmap(val)
val = self._downsize(val, "nearest").cpu()
val = overlay(val, frames2_down)
visuals.append(val)
if "map_target_soft" in outs:
val = outs["map_target_soft"]
val = self._mask_rgb(val)
val = self._downsize(val)
visuals.append(val)
# embedding error mask
if "error_map" in outs:
err_mask = outs["error_map"]
err_mask = (err_mask - err_mask.min()) / (err_mask.max() - err_mask.min() + 1e-8)
err_mask_rgb = self._error_rgb(err_mask, cmap=cm.get_cmap("plasma"), alpha=0.5)
err_mask_rgb = self._downsize(err_mask_rgb)
visuals.append(err_mask_rgb)
if "aff_mask1" in outs:
aff_mask = outs["aff_mask1"].unsqueeze(1).expand(-1,3,-1,-1).cpu()
aff_mask = self._downsize(aff_mask)
aff_frames = frames_orig.clone()
aff_frames[::T] = overlay(aff_mask, aff_frames[::T], 0.5)
visuals.append(aff_frames)
aff_mask1 = self._error_rgb(outs["aff11"], cm.get_cmap("inferno"))
aff_mask1 = self._downsize(aff_mask1)
aff_mask1 = overlay(aff_mask1, frames_orig, 0.3)
visuals.append(aff_mask1)
aff_mask2 = self._error_rgb(outs["aff12"], cm.get_cmap("inferno"))
aff_mask2 = self._downsize(aff_mask2)
aff_mask2 = overlay(aff_mask2, frames2_down, 0.3)
visuals.append(aff_mask2)
if "aff_mask2" in outs:
aff_mask = outs["aff_mask2"].unsqueeze(1).expand(-1,3,-1,-1).cpu()
aff_mask = self._downsize(aff_mask)
aff_frames = frames_down.clone()
aff_frames[::T] = overlay(aff_mask, aff_frames[::T], 0.5)
visuals.append(aff_frames)
aff_mask1 = self._error_rgb(outs["aff21"], cm.get_cmap("inferno"))
aff_mask1 = self._downsize(aff_mask1)
aff_mask1 = overlay(aff_mask1, frames_orig, 0.3)
visuals.append(aff_mask1)
aff_mask2 = self._error_rgb(outs["aff22"], cm.get_cmap("inferno"))
aff_mask2 = self._downsize(aff_mask2)
aff_mask2 = overlay(aff_mask2, frames2_down, 0.3)
visuals.append(aff_mask2)
visuals = [x.cpu().float() for x in visuals]
visuals = torch.cat(visuals, -1)
self._visualise_grid(writer, visuals, epoch, tag, 4 * T)
def save_vis_batch(self, key, batch):
if self.vis_batch is None:
self.vis_batch = {}
if key in self.vis_batch:
return
batch_items = []
for el in batch:
el = el.clone().cpu() if torch.is_tensor(el) else el
batch_items.append(el)
self.vis_batch[key] = batch_items
def has_vis_batch(self, key):
return (not self.vis_batch is None and \
key in self.vis_batch)
def _mask_rgb(self, masks, image_norm=None, palette=None, alpha=0.3):
if palette is None:
palette = self.loader.dataset.palette
# visualising masks
masks_conf, masks_idx = torch.max(masks, 1)
masks_conf = masks_conf - F.relu(masks_conf - 1, 0)
masks_idx_rgb = self._apply_cmap(masks_idx.cpu(), palette, mask_conf=masks_conf.cpu())
if not image_norm is None:
return alpha * image_norm + (1 - alpha) * masks_idx_rgb
return masks_idx_rgb
def _apply_cmap(self, mask_idx, palette=None, mask_conf=None, rand=True):
if palette is None:
palette = self.loader.dataset.palette
ignore_mask = (mask_idx == -1).cpu()
# cycle
if rand:
memsize = self.cfg.TRAIN.BATCH_SIZE * self.cfg.MODEL.GRID_SIZE**2
mask_idx = ((mask_idx + 1) * 123) % memsize
# convert mask to RGB
mask = mask_idx.cpu().numpy().astype(np.uint32)
mask_rgb = palette(mask)
mask_rgb = torch.from_numpy(mask_rgb[:,:,:,:3])
mask_rgb[ignore_mask] *= 0
mask_rgb = mask_rgb.permute(0,3,1,2)
if not mask_conf is None:
# entropy
mask_rgb *= mask_conf[:, None, :, :]
return mask_rgb
def _error_rgb(self, error_mask, cmap = cm.get_cmap('jet'), image=None, alpha=0.3):
error_np = error_mask.cpu().numpy()
# remove alpha channel
error_rgb = cmap(error_np)[:, :, :, :3]
error_rgb = np.transpose(error_rgb, (0,3,1,2))
error_rgb = torch.from_numpy(error_rgb)
if not image is None:
return alpha * image + (1 - alpha) * error_rgb
return error_rgb
def _visualise_grid(self, writer, x_all, t, tag, T=1):
# adding the labels to images
bs, ch, h, w = x_all.size()
x_all_new = torch.zeros(T, ch, h, w)
for b in range(bs):
x_all_new[b % T] = x_all[b]
if (b + 1) % T == 0:
summary_grid = vutils.make_grid(x_all_new, nrow=1, padding=8, pad_value=0.9).numpy()
writer.add_image(tag + "_{:02d}".format(b // T), summary_grid, t)
x_all_new.zero_()
def visualise_results(self, epoch, writer, tag, step_func):
# visualising
self.net.eval()
with torch.no_grad():
step_func(epoch, self.vis_batch[tag], \
train=False, visualise=True, \
writer=writer, tag=tag)
================================================
FILE: configs/kinetics.yaml
================================================
DATASET:
ROOT: "data"
SMALLEST_RANGE: [256,256]
RND_CROP: False
RND_ZOOM: True
RND_ZOOM_RANGE: [.5, 1.]
GUIDED_HFLIP: True
VIDEO_LEN: 5
FRAME_GAP: 10
MP4: True
TRAIN:
BATCH_SIZE: 16
NUM_EPOCHS: 80
TASK: "kinetics400"
MODEL:
LR: 0.001
CE_REF: 0.1
OPT: "SGD"
LR_SCHEDULER: "step"
LR_GAMMA: 0.5
LR_STEP: 20
WEIGHT_DECAY: 0.0005
LOG:
ITER_TRAIN: 8
ITER_VAL: 2
TEST:
KNN: 10
CXT_SIZE: 20
RADIUS: 12
TEMP: 0.05
TB:
IM_SIZE: [196, 196]
================================================
FILE: configs/oxuva.yaml
================================================
DATASET:
ROOT: "data"
SMALLEST_RANGE: [256, 320]
RND_CROP: True
RND_ZOOM: True
RND_ZOOM_RANGE: [.5, 1.]
GUIDED_HFLIP: True
VIDEO_LEN: 5
FRAME_GAP: 10
TRAIN:
BATCH_SIZE: 16
NUM_EPOCHS: 500
TASK: "OxUvA"
MODEL:
LR: 0.0001
OPT: "Adam"
LR_SCHEDULER: "step"
LR_GAMMA: 0.5
LR_STEP: 100
WEIGHT_DECAY: 0.0005
LOG:
ITER_TRAIN: 20
ITER_VAL: 10
TEST:
KNN: 10
CXT_SIZE: 20
RADIUS: 12
TEMP: 0.05
TB:
IM_SIZE: [196, 196]
================================================
FILE: configs/tracknet.yaml
================================================
DATASET:
ROOT: "data"
SMALLEST_RANGE: [256,256]
RND_CROP: False
RND_ZOOM: True
RND_ZOOM_RANGE: [.5, 1.]
GUIDED_HFLIP: True
VIDEO_LEN: 5
FRAME_GAP: 10
TRAIN:
BATCH_SIZE: 16
NUM_EPOCHS: 250
TASK: "TrackingNet"
BLOCK_BN: False
STOP_GRAD: True
MODEL:
LR: 0.001
CE_REF: 0.1
OPT: "SGD"
LR_SCHEDULER: "step"
LR_GAMMA: 0.9
LR_STEP: 20
WEIGHT_DECAY: 0.0005
LOG:
ITER_TRAIN: 10
ITER_VAL: 2
TEST:
KNN: 10
CXT_SIZE: 20
RADIUS: 12
TEMP: 0.05
TB:
IM_SIZE: [196, 196]
================================================
FILE: configs/ytvos.yaml
================================================
DATASET:
ROOT: "data"
SMALLEST_RANGE: [256, 320]
RND_CROP: True
RND_ZOOM: True
RND_ZOOM_RANGE: [.5, 1.]
GUIDED_HFLIP: True
VIDEO_LEN: 5
FRAME_GAP: 2
TRAIN:
BATCH_SIZE: 16
NUM_EPOCHS: 500
TASK: "YTVOS"
MODEL:
LR: 0.0001
OPT: "Adam"
LR_SCHEDULER: "step"
LR_GAMMA: 0.5
LR_STEP: 100
WEIGHT_DECAY: 0.0005
LOG:
ITER_TRAIN: 20
ITER_VAL: 10
TEST:
KNN: 10
CXT_SIZE: 20
RADIUS: 12
TEMP: 0.05
TB:
IM_SIZE: [196, 196]
================================================
FILE: core/__init__.py
================================================
================================================
FILE: core/config.py
================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import yaml
import six
import os
import os.path as osp
import copy
from ast import literal_eval
import numpy as np
from packaging import version
from utils.collections import AttrDict
__C = AttrDict()
# Consumers can get config by:
# from fast_rcnn_config import cfg
cfg = __C
__C.NUM_GPUS = 1
# Random note: avoid using '.ON' as a config key since yaml converts it to True;
# prefer 'ENABLED' instead
# ---------------------------------------------------------------------------- #
# Training schedule options
# ---------------------------------------------------------------------------- #
__C.TRAIN = AttrDict()
__C.TRAIN.BATCH_SIZE = 20
__C.TRAIN.NUM_EPOCHS = 15
__C.TRAIN.NUM_WORKERS = 4
__C.TRAIN.TASK = "YTVOS"
__C.TRAIN.BLOCK_BN = False
__C.TRAIN.STOP_GRAD = True
# ---------------------------------------------------------------------------- #
# Dataset options (+ augmentation options)
# ---------------------------------------------------------------------------- #
__C.DATASET = AttrDict()
__C.DATASET.CROP_SIZE = [256,256]
__C.DATASET.SMALLEST_RANGE = [256,320]
__C.DATASET.RND_CROP = False
__C.DATASET.RND_HFLIP = True
__C.DATASET.MP4 = False
__C.DATASET.RND_ZOOM = False
__C.DATASET.RND_ZOOM_RANGE = [.5,1.]
__C.DATASET.GUIDED_HFLIP = True
__C.DATASET.ROOT = "data"
__C.DATASET.VIDEO_LEN = 5
__C.DATASET.FRAME_GAP = 5
__C.DATASET.VAL_FRAME_GAP = 2
# size of the augmentation
# (how many augmented copies to create for 1 image)
__C.DATASET.NUM_CLASSES = 7
# inference-time parameters
__C.TEST = AttrDict()
__C.TEST.RADIUS = 12
__C.TEST.TEMP = 0.05
__C.TEST.KNN = 10
__C.TEST.CXT_SIZE = 20
__C.TEST.INPUT_SIZE = -1
__C.TEST.KEY = "res4"
# ---------------------------------------------------------------------------- #
# Model options
# ---------------------------------------------------------------------------- #
__C.MODEL = AttrDict()
__C.MODEL.ARCH = 'resnet18'
__C.MODEL.LR_SCHEDULER = 'poly'
__C.MODEL.LR_STEP = 5
__C.MODEL.LR_GAMMA = 0.1 # divide by 10 every LR_STEP epochs
__C.MODEL.LR_POWER = 1.0
__C.MODEL.LR_SCHED_USE_EPOCH = True
__C.MODEL.OPT = 'Adam'
__C.MODEL.OPT_NESTEROV = False
__C.MODEL.LR = 3e-4
__C.MODEL.BETA1 = 0.5
__C.MODEL.MOMENTUM = 0.9
__C.MODEL.WEIGHT_DECAY = 1e-5
__C.MODEL.REMOVE_LAYERS = []
__C.MODEL.FEATURE_DIM = 128
__C.MODEL.GRID_SIZE = 8
__C.MODEL.GRID_SIZE_REF = 4
__C.MODEL.CE_REF = 0.1
# ---------------------------------------------------------------------------- #
# Logging options
# ---------------------------------------------------------------------------- #
__C.LOG = AttrDict()
__C.LOG.ITER_VAL = 1
__C.LOG.ITER_TRAIN = 10
# ---------------------------------------------------------------------------- #
# TensorBoard options
# ---------------------------------------------------------------------------- #
__C.TB = AttrDict()
__C.TB.IM_SIZE = (196, 196) # image resolution
# [Inferred value]
__C.CUDA = False
__C.DEBUG = False
# [Inferred value]
__C.PYTORCH_VERSION_LESS_THAN_040 = False
def assert_and_infer_cfg(make_immutable=True):
"""Call this function in your script after you have finished setting all cfg
values that are necessary (e.g., merging a config from a file, merging
command line config options, etc.). By default, this function will also
mark the global cfg as immutable to prevent changing the global cfg settings
during script execution (which can lead to hard to debug errors or code
that's harder to understand than is necessary).
"""
if make_immutable:
cfg.immutable(True)
def merge_cfg_from_file(cfg_filename):
"""Load a yaml config file and merge it into the global config."""
with open(cfg_filename, 'r') as f:
yaml_cfg = AttrDict(yaml.load(f, Loader=yaml.FullLoader))
_merge_a_into_b(yaml_cfg, __C)
cfg_from_file = merge_cfg_from_file
def merge_cfg_from_cfg(cfg_other):
"""Merge `cfg_other` into the global config."""
_merge_a_into_b(cfg_other, __C)
def merge_cfg_from_list(cfg_list):
"""Merge config keys, values in a list (e.g., from command line) into the
global config. For example, `cfg_list = ['TEST.NMS', 0.5]`.
"""
assert len(cfg_list) % 2 == 0
for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]):
key_list = full_key.split('.')
d = __C
for subkey in key_list[:-1]:
assert subkey in d, 'Non-existent key: {}'.format(full_key)
d = d[subkey]
subkey = key_list[-1]
assert subkey in d, 'Non-existent key: {}'.format(full_key)
value = _decode_cfg_value(v)
value = _check_and_coerce_cfg_value_type(
value, d[subkey], subkey, full_key
)
d[subkey] = value
cfg_from_list = merge_cfg_from_list
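# Example usage (a minimal sketch; the values are illustrative):
#   merge_cfg_from_file("configs/ytvos.yaml")
#   merge_cfg_from_list(["TRAIN.BATCH_SIZE", "8", "MODEL.LR", "1e-4"])
# Keys must already exist among the defaults above; string values are parsed
# with literal_eval and coerced to the type of the default they replace.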
def _merge_a_into_b(a, b, stack=None):
"""Merge config dictionary a into config dictionary b, clobbering the
options in b whenever they are also specified in a.
"""
assert isinstance(a, AttrDict), 'Argument `a` must be an AttrDict'
assert isinstance(b, AttrDict), 'Argument `b` must be an AttrDict'
for k, v_ in a.items():
full_key = '.'.join(stack) + '.' + k if stack is not None else k
# a must specify keys that are in b
if k not in b:
raise KeyError('Non-existent config key: {}'.format(full_key))
v = copy.deepcopy(v_)
v = _decode_cfg_value(v)
v = _check_and_coerce_cfg_value_type(v, b[k], k, full_key)
# Recursively merge dicts
if isinstance(v, AttrDict):
try:
stack_push = [k] if stack is None else stack + [k]
_merge_a_into_b(v, b[k], stack=stack_push)
except BaseException:
raise
else:
b[k] = v
def _decode_cfg_value(v):
"""Decodes a raw config value (e.g., from a yaml config files or command
line argument) into a Python object.
"""
# Configs parsed from raw yaml will contain dictionary keys that need to be
# converted to AttrDict objects
if isinstance(v, dict):
return AttrDict(v)
# All remaining processing is only applied to strings
if not isinstance(v, six.string_types):
return v
# Try to interpret `v` as a:
# string, number, tuple, list, dict, boolean, or None
try:
v = literal_eval(v)
# The following two excepts allow v to pass through when it represents a
# string.
#
# Longer explanation:
# The type of v is always a string (before calling literal_eval), but
# sometimes it *represents* a string and other times a data structure, like
# a list. In the case that v represents a string, what we got back from the
# yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is
# ok with '"foo"', but will raise a ValueError if given 'foo'. In other
# cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval
# will raise a SyntaxError.
except ValueError:
pass
except SyntaxError:
pass
return v
def _check_and_coerce_cfg_value_type(value_a, value_b, key, full_key):
"""Checks that `value_a`, which is intended to replace `value_b` is of the
right type. The type is correct if it matches exactly or is one of a few
cases in which the type can be easily coerced.
"""
# The types must match (with some exceptions)
type_b = type(value_b)
type_a = type(value_a)
if type_a is type_b:
return value_a
# Exceptions: numpy arrays, strings, tuple<->list
if isinstance(value_b, np.ndarray):
value_a = np.array(value_a, dtype=value_b.dtype)
elif isinstance(value_b, six.string_types):
value_a = str(value_a)
elif isinstance(value_a, tuple) and isinstance(value_b, list):
value_a = list(value_a)
elif isinstance(value_a, list) and isinstance(value_b, tuple):
value_a = tuple(value_a)
else:
raise ValueError(
'Type mismatch ({} vs. {}) with values ({} vs. {}) for config '
'key: {}'.format(type_b, type_a, value_b, value_a, full_key)
)
return value_a
================================================
FILE: datasets/__init__.py
================================================
"""
Copyright (c) 2021 TU Darmstadt
Author: Nikita Araslanov <nikita.araslanov@tu-darmstadt.de>
License: Apache License 2.0
"""
import torch
from torch.utils import data
from .dataloader_seg import DataSeg
from .dataloader_video import DataVideo, DataVideoKinetics
def get_sets(task):
# fetch the names of data lists depending on the task
sets = {}
if task == "OxUvA":
sets["train_video"] = "train_oxuva"
sets["val_video"] = "val_davis2017_480p"
sets["val_video_seg"] = "val2_davis2017_480p"
elif task == "YTVOS":
sets["train_video"] = "train_ytvos"
sets["val_video"] = "val_davis2017_480p"
sets["val_video_seg"] = "val2_davis2017_480p"
elif task == "TrackingNet":
sets["train_video"] = "train_tracking"
sets["val_video"] = "val_davis2017_480p"
sets["val_video_seg"] = "val2_davis2017_480p"
elif task == "kinetics400":
sets["train_video"] = "train_kinetics400"
sets["val_video"] = "val_davis2017_480p"
sets["val_video_seg"] = "val2_davis2017_480p"
else:
raise NotImplementedError("Dataset '{}' not recognised.".format(task))
return sets
def get_dataloader(args, cfg, split):
assert split in ("train", "val")
task = cfg.TRAIN.TASK
data_sets = get_sets(cfg.TRAIN.TASK)
# total batch size: # of GPUs * batch size per GPU
batch_size = cfg.TRAIN.BATCH_SIZE
kwargs = {'pin_memory': True, 'num_workers': args.workers}
print("Dataloader: # workers {}".format(args.workers))
def _dataloader(dataset, batch_size, shuffle=True, drop_last=False):
return data.DataLoader(dataset, batch_size=batch_size, \
shuffle=shuffle, drop_last=drop_last, **kwargs)
print("Split: ", split)
if split == "train":
VideoLoader = DataVideoKinetics if cfg.DATASET.MP4 else DataVideo
dataset_video = VideoLoader(cfg, data_sets["train_video"])
return _dataloader(dataset_video, batch_size, drop_last=True)
else:
dataset_video = DataVideo(cfg, data_sets["val_video"], val=True)
dataset_video_seg = DataSeg(cfg, data_sets["val_video_seg"])
return {"val_video": _dataloader(dataset_video, batch_size, shuffle=False), \
"val_video_seg": _dataloader(dataset_video_seg, 1, shuffle=False)}
================================================
FILE: datasets/dataloader_base.py
================================================
"""
Copyright (c) 2021 TU Darmstadt
Author: Nikita Araslanov <nikita.araslanov@tu-darmstadt.de>
License: Apache License 2.0
"""
import torch.utils.data as data
from utils.palette import custom_palette
class DLBase(data.Dataset):
def __init__(self, *args, **kwargs):
super(DLBase, self).__init__(*args, **kwargs)
# RGB
self.MEAN = [0.485, 0.456, 0.406]
self.STD = [0.229, 0.224, 0.225]
self._init_means()
def _init_means(self):
self.MEAN255 = [255.*x for x in self.MEAN]
self.STD255 = [255.*x for x in self.STD]
def _init_palette(self, num_classes):
self.palette = custom_palette(num_classes)
def get_palette(self):
return self.palette
def remove_labels(self, mask):
# Remove labels not in training
for ignore_label in self.ignore_labels:
mask[mask == ignore_label] = 255
return mask.long()
================================================
FILE: datasets/dataloader_infer.py
================================================
"""
Copyright (c) 2021 TU Darmstadt
Author: Nikita Araslanov <nikita.araslanov@tu-darmstadt.de>
License: Apache License 2.0
"""
import os
import torch
from PIL import Image
import numpy as np
import torchvision.transforms as tf
from .dataloader_base import DLBase
class DataSeg(DLBase):
def __init__(self, cfg, split, ignore_labels=[], \
root=os.path.expanduser('./data'), renorm=False):
super(DataSeg, self).__init__()
self.cfg = cfg
self.root = root
self.split = split
self.ignore_labels = ignore_labels
self._init_palette(self.cfg.DATASET.NUM_CLASSES)
# train/val/test splits are pre-cut
split_fn = os.path.join(self.root, self.split + ".txt")
assert os.path.isfile(split_fn)
self.sequence_ids = []
self.sequence_names = []
def add_sequence(name):
vlen = len(self.images)
assert vlen >= cfg.DATASET.VIDEO_LEN, \
"Detected video shorter [{}] than training length [{}]".format(vlen, \
cfg.DATASET.VIDEO_LEN)
self.sequence_ids.append(vlen)
self.sequence_names.append(name)
return vlen
self.images = []
self.masks = []
self.flags = []
token = None
with open(split_fn, "r") as lines:
for line in lines:
_flag, _image, _mask = line.strip("\n").split(' ')
# save every frame
#_flag = 1
self.flags.append(int(_flag))
_image = os.path.join(cfg.DATASET.ROOT, _image.lstrip('/'))
assert os.path.isfile(_image), '%s not found' % _image
# each sequence may have a different length
# do some book-keeping e.g. to ensure we have
# sequences long enough for subsequent sampling
_token = _image.split("/")[-2] # parent directory
# sequence ID is in the filename
#_token = os.path.basename(_image).split("_")[0]
if token != _token:
if not token is None:
add_sequence(token)
token = _token
self.images.append(_image)
if _mask is None:
self.masks.append(None)
else:
_mask = os.path.join(cfg.DATASET.ROOT, _mask.lstrip('/'))
#assert os.path.isfile(_mask), '%s not found' % _mask
self.masks.append(_mask)
# update the last sequence
# returns the total amount of frames
add_sequence(token)
print("Loaded {} sequences".format(len(self.sequence_ids)))
# defining data augmentation:
print("Dataloader: {}".format(split), " #", len(self.images))
print("\t {}: no augmentation".format(split))
self.tf = tf.Compose([tf.ToTensor(), tf.Normalize(mean=self.MEAN, std=self.STD)])
self._num_samples = len(self.images)
def __len__(self):
return len(self.sequence_ids)
def _mask2tensor(self, mask, num_classes=6):
h,w = mask.shape
ones = torch.ones(1,h,w)
zeros = torch.zeros(num_classes,h,w)
max_idx = mask.max()
assert max_idx < num_classes, "{} >= {}".format(max_idx, num_classes)
return zeros.scatter(0, mask[None, ...], ones)
def denorm(self, image):
if image.dim() == 3:
assert image.dim() == 3, "Expected image [CxHxW]"
assert image.size(0) == 3, "Expected RGB image [3xHxW]"
for t, m, s in zip(image, self.MEAN, self.STD):
t.mul_(s).add_(m)
elif image.dim() == 4:
# batch mode
assert image.size(1) == 3, "Expected RGB image [3xHxW]"
for t, m, s in zip((0,1,2), self.MEAN, self.STD):
image[:, t, :, :].mul_(s).add_(m)
return image
def __getitem__(self, index):
seq_to = self.sequence_ids[index]
seq_from = 0 if index == 0 else self.sequence_ids[index - 1]
image0 = Image.open(self.images[seq_from])
w,h = image0.size
images, masks, fns, flags = [], [], [], []
tracks = torch.LongTensor(self.cfg.DATASET.NUM_CLASSES).fill_(-1)
masks = torch.LongTensor(self.cfg.DATASET.NUM_CLASSES, h, w).zero_()
known_ids = set()
for t in range(seq_from, seq_to):
t0 = t - seq_from
image = Image.open(self.images[t]).convert('RGB')
fns.append(os.path.basename(self.images[t].replace(".jpg", "")))
flags.append(self.flags[t])
if os.path.isfile(self.masks[t]):
mask = Image.open(self.masks[t])
mask = torch.from_numpy(np.array(mask, np.long, copy=False))
unique_ids = np.unique(mask)
for oid in unique_ids:
if not oid in known_ids:
tracks[oid] = t0
known_ids.add(oid)
masks[oid] = (mask == oid).long()
else:
mask = Image.new('L', image.size)
image = self.tf(image)
images.append(image)
images = torch.stack(images, 0)
seq_name = self.sequence_names[index]
flags = torch.LongTensor(flags)
return images, images, masks, tracks, len(known_ids), fns, flags, seq_name
================================================
FILE: datasets/dataloader_seg.py
================================================
"""
Copyright (c) 2021 TU Darmstadt
Author: Nikita Araslanov <nikita.araslanov@tu-darmstadt.de>
License: Apache License 2.0
"""
import os
import torch
from PIL import Image
import numpy as np
import torch.utils.data as data
import torchvision.transforms as tf
from .dataloader_base import DLBase
class DataSeg(DLBase):
def __init__(self, cfg, split, ignore_labels=[], \
root=os.path.expanduser('./data'), renorm=False):
super(DataSeg, self).__init__()
self.cfg = cfg
self.root = root
self.split = split
self.ignore_labels = ignore_labels
self._init_palette(cfg.DATASET.NUM_CLASSES)
# train/val/test splits are pre-cut
split_fn = os.path.join(self.root, "filelists", self.split + ".txt")
assert os.path.isfile(split_fn)
self.sequence_ids = []
self.sequence_names = []
def add_sequence(name):
vlen = len(self.images)
assert vlen >= cfg.DATASET.VIDEO_LEN, \
"Detected video shorter [{}] than training length [{}]".format(vlen, \
cfg.DATASET.VIDEO_LEN)
self.sequence_ids.append(vlen)
self.sequence_names.append(name)
return vlen
self.images = []
self.masks = []
token = None
with open(split_fn, "r") as lines:
for line in lines:
_image = line.strip("\n").split(' ')
_mask = None
if len(_image) == 2:
_image, _mask = _image
else:
assert len(_image) == 1
_image = _image[0]
_image = os.path.join(cfg.DATASET.ROOT, _image.lstrip('/'))
assert os.path.isfile(_image), '%s not found' % _image
self.images.append(_image)
# each sequence may have a different length
# do some book-keeping e.g. to ensure we have
# sequences long enough for subsequent sampling
_token = _image.split("/")[-2] # parent directory
# sequence ID is in the filename
#_token = os.path.basename(_image).split("_")[0]
if token != _token:
if not token is None:
add_sequence(token)
token = _token
if _mask is None:
self.masks.append(None)
else:
_mask = os.path.join(cfg.DATASET.ROOT, _mask.lstrip('/'))
assert os.path.isfile(_mask), '%s not found' % _mask
self.masks.append(_mask)
# update the last sequence
# returns the total amount of frames
add_sequence(token)
print("Loaded {} sequences".format(len(self.sequence_ids)))
# defining data augmentation:
print("Dataloader: {}".format(split), " #", len(self.images))
print("\t {}: no augmentation".format(split))
self.tf = tf.Compose([tf.ToTensor(), tf.Normalize(mean=self.MEAN, std=self.STD)])
self._num_samples = len(self.images)
def __len__(self):
return len(self.sequence_ids)
def denorm(self, image):
if image.dim() == 3:
assert image.dim() == 3, "Expected image [CxHxW]"
assert image.size(0) == 3, "Expected RGB image [3xHxW]"
for t, m, s in zip(image, self.MEAN, self.STD):
t.mul_(s).add_(m)
elif image.dim() == 4:
# batch mode
assert image.size(1) == 3, "Expected RGB image [3xHxW]"
for t, m, s in zip((0,1,2), self.MEAN, self.STD):
image[:, t, :, :].mul_(s).add_(m)
return image
def _mask2tensor(self, mask, num_classes=6):
h,w = mask.shape
ones = torch.ones(1,h,w)
zeros = torch.zeros(num_classes,h,w)
max_idx = mask.max()
assert max_idx < num_classes, "{} >= {}".format(max_idx, num_classes)
return zeros.scatter(0, mask[None, ...], ones)
def __getitem__(self, index):
seq_to = self.sequence_ids[index] - 1
seq_from = 0 if index == 0 else self.sequence_ids[index-1] - 1
images, masks = [], []
n_obj = 0
for _id_ in range(seq_from, seq_to):
image = Image.open(self.images[_id_]).convert('RGB')
if self.masks[_id_] is None:
mask = Image.new('L', image.size)
else:
mask = Image.open(self.masks[_id_]) #.convert('L')
image = self.tf(image)
images.append(image)
mask = torch.from_numpy(np.array(mask, np.long, copy=False))
n_obj = max(n_obj, mask.max().item())
masks.append(self._mask2tensor(mask))
images = torch.stack(images, 0)
masks = torch.stack(masks, 0)
n_obj = torch.LongTensor([n_obj + 1]) # +1 background
seq_name = self.sequence_names[index]
return images, masks, n_obj, seq_name
================================================
FILE: datasets/dataloader_video.py
================================================
"""
Copyright (c) 2021 TU Darmstadt
Author: Nikita Araslanov <nikita.araslanov@tu-darmstadt.de>
License: Apache License 2.0
"""
import os
import random
import torch
import math
import glob
from PIL import Image
import numpy as np
from .dataloader_base import DLBase
import datasets.daugm_video as tf
class DataVideo(DLBase):
def __init__(self, cfg, split, val=False):
super(DataVideo, self).__init__()
self.cfg = cfg
self.split = split
self.val = val
self.cfg_frame_gap = cfg.DATASET.VAL_FRAME_GAP if val else cfg.DATASET.FRAME_GAP
self._init_palette(cfg.TRAIN.BATCH_SIZE * cfg.MODEL.GRID_SIZE**2)
# train/val/test splits are pre-cut
split_fn = os.path.join(cfg.DATASET.ROOT, "filelists", self.split + ".txt")
assert os.path.isfile(split_fn)
self.images = []
token = None # sequence token (new video when it changes)
subsequence = []
ignored = [0]
num_frames = [0]
def add_sequence():
if cfg.DATASET.VIDEO_LEN > len(subsequence):
# found a very short sequence
ignored[0] += 1
else:
# adding the subsequence
self.images.append(tuple(subsequence))
num_frames[0] += len(subsequence)
del subsequence[:]
with open(split_fn, "r") as lines:
for line in lines:
_line = line.strip("\n").split(' ')
assert len(_line) > 0, "Expected at least one path"
_image = _line[0]
# each sequence may have a different length
# do some book-keeping e.g. to ensure we have
# sequences long enough for subsequent sampling
_token = _image.split("/")[-2] # parent directory
# sequence ID is in the filename
#_token = os.path.basename(_image).split("_")[0]
if token != _token:
if not token is None:
add_sequence()
token = _token
# image 1
_image = os.path.join(cfg.DATASET.ROOT, _image.lstrip('/'))
#assert os.path.isfile(_image), '%s not found' % _image
subsequence.append(_image)
# update the last sequence
# returns the total amount of frames
add_sequence()
print("Dataloader: {}".format(split), " / Frame Gap: ", self.cfg_frame_gap)
print("Loaded {} sequences / {} ignored / {} frames".format(len(self.images), \
ignored[0], \
num_frames[0]))
self._num_samples = num_frames[0]
self._init_augm(cfg)
def _init_augm(self, cfg):
# general (unguided) affine transformations
tfs_pre = [tf.CreateMask()]
self.tf_pre = tf.Compose(tfs_pre)
# photometric noise
tfs_affine = []
# guided affine transformations
tfs_augm = []
# 1.
# general affine transformations
#
tfs_pre.append(tf.MaskScaleSmallest(cfg.DATASET.SMALLEST_RANGE))
if cfg.DATASET.RND_CROP:
tfs_pre.append(tf.MaskRandCrop(cfg.DATASET.CROP_SIZE, pad_if_needed=True))
else:
tfs_pre.append(tf.MaskCenterCrop(cfg.DATASET.CROP_SIZE))
if cfg.DATASET.RND_HFLIP:
tfs_pre.append(tf.MaskRandHFlip())
# 2.
# Guided affine transformation
#
if cfg.DATASET.GUIDED_HFLIP:
tfs_affine.append(tf.GuidedRandHFlip())
# this will add affine transformation
if cfg.DATASET.RND_ZOOM:
tfs_affine.append(tf.MaskRandScaleCrop(*cfg.DATASET.RND_ZOOM_RANGE))
self.tf_affine = tf.Compose(tfs_affine)
self.tf_affine2 = tf.Compose([tf.AffineIdentity()])
tfs_post = [tf.ToTensorMask(),
tf.Normalize(mean=self.MEAN, std=self.STD),
tf.ApplyMask(-1)]
# image to the teacher will have no noise
self.tf_post = tf.Compose(tfs_post)
def set_num_samples(self, n):
print("Re-setting # of samples: {:d} -> {:d}".format(self._num_samples, n))
self._num_samples = n
def __len__(self):
return len(self.images) #self._num_samples
def denorm(self, image):
if image.dim() == 3:
assert image.dim() == 3, "Expected image [CxHxW]"
assert image.size(0) == 3, "Expected RGB image [3xHxW]"
for t, m, s in zip(image, self.MEAN, self.STD):
t.mul_(s).add_(m)
elif image.dim() == 4:
# batch mode
assert image.size(1) == 3, "Expected RGB image [3xHxW]"
for t, m, s in zip((0,1,2), self.MEAN, self.STD):
image[:, t, :, :].mul_(s).add_(m)
return image
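# `params` holds one entry per frame of the form (dy, dx, alpha, scale, flip),
# as produced by the transforms in datasets/daugm_video.py (e.g. GuidedRandHFlip,
# AffineIdentity); _get_affine converts them into inverse 2x3 affine matrices.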
def _get_affine(self, params):
N = len(params)
# construct affine operator
affine = torch.zeros(N, 2, 3)
aspect_ratio = float(self.cfg.DATASET.CROP_SIZE[0]) / \
float(self.cfg.DATASET.CROP_SIZE[1])
for i, (dy,dx,alpha,scale,flip) in enumerate(params):
# R inverse
sin = math.sin(alpha * math.pi / 180.)
cos = math.cos(alpha * math.pi / 180.)
# inverse, note how flipping is incorporated
affine[i,0,0], affine[i,0,1] = flip * cos, sin * aspect_ratio
affine[i,1,0], affine[i,1,1] = -sin / aspect_ratio, cos
# T inverse Rinv * t == R^T * t
affine[i,0,2] = -1. * (cos * dx + sin * dy)
affine[i,1,2] = -1. * (-sin * dx + cos * dy)
# T
affine[i,0,2] /= float(self.cfg.DATASET.CROP_SIZE[1] // 2)
affine[i,1,2] /= float(self.cfg.DATASET.CROP_SIZE[0] // 2)
# scaling
affine[i] *= scale
return affine
def _get_affine_inv(self, affine, params):
aspect_ratio = float(self.cfg.DATASET.CROP_SIZE[0]) / \
float(self.cfg.DATASET.CROP_SIZE[1])
affine_inv = affine.clone()
affine_inv[:,0,1] = affine[:,1,0] * aspect_ratio**2
affine_inv[:,1,0] = affine[:,0,1] / aspect_ratio**2
affine_inv[:,0,2] = -1 * (affine_inv[:,0,0] * affine[:,0,2] + affine_inv[:,0,1] * affine[:,1,2])
affine_inv[:,1,2] = -1 * (affine_inv[:,1,0] * affine[:,0,2] + affine_inv[:,1,1] * affine[:,1,2])
# scaling
affine_inv /= torch.Tensor(params)[:,3].view(-1,1,1)**2
return affine_inv
def __getitem__(self, index):
# searching for the video clip ID
sequence = self.images[index] # % len(self.images)]
seqlen = len(sequence)
assert self.cfg_frame_gap > 0, "Frame gap should be positive"
t_window = self.cfg_frame_gap * self.cfg.DATASET.VIDEO_LEN
# reduce sampling gap for short clips
t_window = min(seqlen, t_window)
frame_gap = t_window // self.cfg.DATASET.VIDEO_LEN
# strided slice
frame_ids = torch.arange(t_window)[::frame_gap]
frame_ids = frame_ids[:self.cfg.DATASET.VIDEO_LEN]
assert len(frame_ids) == self.cfg.DATASET.VIDEO_LEN
# selecting a random start
index_start = random.randint(0, seqlen - frame_ids[-1] - 1)
# permuting the frames in the batch
random_ids = torch.randperm(self.cfg.DATASET.VIDEO_LEN)
# adding the offset
frame_ids = frame_ids[random_ids] + index_start
# forward sequence
images = []
for frame_id in frame_ids:
fn = sequence[frame_id]
images.append(Image.open(fn).convert('RGB'))
# 1. general transforms
frames, valid = self.tf_pre(images)
# 1.1 creating two sequences in forward/backward order
frames1, valid1 = frames[:], valid[:]
# second copy
frames2 = [f.copy() for f in frames]
valid2 = [v.copy() for v in valid]
# 2. guided affine transforms
frames1, valid1, affine_params1 = self.tf_affine(frames1, valid1)
frames2, valid2, affine_params2 = self.tf_affine2(frames2, valid2)
# convert to tensor, zero out the values
frames1 = self.tf_post(frames1, valid1)
frames2 = self.tf_post(frames2, valid2)
# converting the affine transforms
aff_reg = self._get_affine(affine_params1)
aff_main = self._get_affine(affine_params2)
aff_reg_inv = self._get_affine_inv(aff_reg, affine_params1)
aff_reg = aff_main # identity affine2_inv
aff_main = aff_reg_inv
frames1 = torch.stack(frames1, 0)
frames2 = torch.stack(frames2, 0)
assert frames1.shape == frames2.shape
return frames2, frames1, aff_main, aff_reg
class DataVideoKinetics(DataVideo):
def __init__(self, cfg, split):
super(DataVideo, self).__init__()
self.cfg = cfg
self.split = split
self._init_palette(cfg.TRAIN.BATCH_SIZE * cfg.MODEL.GRID_SIZE**2)
# train/val/test splits are pre-cut
split_fn = os.path.join(cfg.DATASET.ROOT, "filelists", self.split + ".txt")
assert os.path.isfile(split_fn)
self.videos = []
with open(split_fn, "r") as lines:
for line in lines:
_line = line.strip("\n").split(' ')
assert len(_line) > 0, "Expected at least one path"
_vid = _line[0]
# image 1
_vid = os.path.join(cfg.DATASET.ROOT, _vid.lstrip('/'))
#assert os.path.isdir(_vid), "{} does not exist".format(_vid)
self.videos.append(_vid)
# update the last sequence
# returns the total amount of frames
print("DataloaderKinetics: {}".format(split))
print("Loaded {} sequences".format(len(self.videos)))
self._init_augm(cfg)
def __len__(self):
return len(self.videos)
def __getitem__(self, index):
C = self.cfg.DATASET
path = self.videos[index]
# filenames
fns = sorted(glob.glob(path + "/*.jpg"))
total_len = len(fns)
# temporal window to consider
temp_window = min(total_len, C.VIDEO_LEN * C.FRAME_GAP)
gap = temp_window / C.VIDEO_LEN
start_frame = random.randint(0, total_len - temp_window)
images = []
for idx in range(C.VIDEO_LEN):
frame_id = start_frame + int(idx * gap)
fn = fns[frame_id]
images.append(Image.open(fn).convert('RGB'))
# 1. general transforms
frames, valid = self.tf_pre(images)
# 1.1 creating two sequences in forward/backward order
frames1, valid1 = frames[:], valid[:]
# second copy
frames2 = [f.copy() for f in frames]
valid2 = [v.copy() for v in valid]
# 2. guided affine transforms
frames1, valid1, affine_params1 = self.tf_affine(frames1, valid1)
frames2, valid2, affine_params2 = self.tf_affine2(frames2, valid2)
# convert to tensor, zero out the values
frames1 = self.tf_post(frames1, valid1)
frames2 = self.tf_post(frames2, valid2)
# converting the affine transforms
affine1 = self._get_affine(affine_params1)
affine2 = self._get_affine(affine_params2)
affine1_inv = self._get_affine_inv(affine1, affine_params1)
affine1 = affine2 # identity affine2_inv
affine2 = affine1_inv
frames1 = torch.stack(frames1, 0)
frames2 = torch.stack(frames2, 0)
assert frames1.shape == frames2.shape
return frames2, frames1, affine2, affine1
================================================
FILE: datasets/daugm_video.py
================================================
"""
Copyright (c) 2021 TU Darmstadt
Author: Nikita Araslanov <nikita.araslanov@tu-darmstadt.de>
License: Apache License 2.0
"""
import random
import numpy as np
import torch
from PIL import Image
import torchvision.transforms as tf
import torchvision.transforms.functional as F
class Compose:
# Composes segtransforms: segtransform.Compose([segtransform.RandScale([0.5, 2.0]), segtransform.ToTensor()])
def __init__(self, segtransform):
self.segtransform = segtransform
def __call__(self, args, *more_args):
# allow for intermediate representations
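# each transform returns a tuple: element 0 feeds the next transform,
# the remaining elements (e.g. masks, affine params) are passed along as extra positional args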
for t in self.segtransform:
result = t(args, *more_args)
args = result[0]
more_args = result[1:]
return result
class ToTensorMask:
def __toByteTensor(self, pic):
return torch.from_numpy(np.array(pic, np.int32, copy=False))
def __call__(self, images, masks):
new_masks = []
for i, (image, mask) in enumerate(zip(images, masks)):
images[i] = F.to_tensor(image)
new_masks.append(self.__toByteTensor(mask))
return images, new_masks
class CreateMask:
"""Create mask to hold invalid pixels
(e.g. from rotations or downscaling)
"""
def __call__(self, images):
masks = []
for i, image in enumerate(images):
masks.append(Image.new("L", image.size))
return images, masks
class Normalize:
# Normalize tensor with mean and standard deviation along channel: channel = (channel - mean) / std
def __init__(self, mean, std=None):
if std is None:
assert len(mean) > 0
else:
assert len(mean) == len(std)
self.mean = mean
self.std = std
def __call__(self, images, masks):
for i, (image, mask) in enumerate(zip(images, masks)):
if self.std is None:
for t, m in zip(image, self.mean):
t.sub_(m)
else:
for t, m, s in zip(image, self.mean, self.std):
t.sub_(m).div_(s)
return images, masks
class ApplyMask:
def __init__(self, ignore_label):
self.ignore_label = ignore_label
def __call__(self, images, masks):
for i, (image, mask) in enumerate(zip(images, masks)):
mask = mask > 0.
images[i] *= (1. - mask.type_as(image))
return images
class GuidedRandHFlip:
def __call__(self, images, mask, affine=None):
if affine is None:
affine = [[0.,0.,0.,1.,1.] for _ in images]
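# per-frame affine parameters [dy, dx, (unused in this file), scale, hflip-sign];
# this transform only toggles the flip sign, MaskRandScaleCrop fills dy, dx and the scale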
if random.random() > 0.5:
for i, image in enumerate(images):
affine[i][4] *= -1
images[i] = F.hflip(image)
return images, mask, affine
class AffineIdentity(object):
def __call__(self, images, masks, affine=None):
if affine is None:
affine = [[0.,0.,0.,1.,1.] for _ in images]
return images, masks, affine
class MaskRandScaleCrop(object):
def __init__(self, scale_from, scale_to):
self.scale_from = scale_from
self.scale_to = scale_to
#assert scale_from >= 1., "Zooming in is not supported yet"
def get_scale(self):
return random.uniform(self.scale_from, self.scale_to)
def get_params(self, h, w, new_scale):
# generating random crop
# preserves aspect ratio
new_h = int(new_scale * h)
new_w = int(new_scale * w)
# generating
if new_scale <= 1.:
assert w >= new_w and h >= new_h, "{} vs. {} | {} / {}".format(w, new_w, h, new_h)
i = random.randint(0, h - new_h)
j = random.randint(0, w - new_w)
else:
assert w <= new_w and h <= new_h, "{} vs. {} | {} / {}".format(w, new_w, h, new_h)
i = random.randint(h - new_h, 0)
j = random.randint(w - new_w, 0)
return i, j, new_h, new_w
def __call__(self, images, masks, affine=None):
if affine is None:
affine = [[0.,0.,0.,1.,1.] for _ in images]
W, H = images[0].size
i2 = H / 2
j2 = W / 2
masks_new = []
# one crop for all
s = self.get_scale()
ii, jj, h, w = self.get_params(H, W, s)
# displacement of the centre
dy = ii + h / 2 - i2
dx = jj + w / 2 - j2
for k, image in enumerate(images):
affine[k][0] = dy
affine[k][1] = dx
affine[k][3] = 1 / s # scale
if s <= 1.:
assert ii >= 0 and jj >= 0
# zooming in
image_crop = F.crop(image, ii, jj, h, w)
images[k] = image_crop.resize((W, H), Image.BILINEAR)
mask_crop = F.crop(masks[k], ii, jj, h, w)
masks_new.append(mask_crop.resize((W, H), Image.NEAREST))
else:
assert ii <= 0 and jj <= 0
# zooming out
pad_l = abs(jj)
pad_r = w - W - pad_l
pad_t = abs(ii)
pad_b = h - H - pad_t
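# pad so the enlarged (h, w) window covers the frame; the mask is padded with 1 to flag the synthetic border as invalid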
image_pad = F.pad(image, (pad_l, pad_t, pad_r, pad_b))
images[k] = image_pad.resize((W, H), Image.BILINEAR)
mask_pad = F.pad(masks[k], (pad_l, pad_t, pad_r, pad_b), 1)
masks_new.append(mask_pad.resize((W, H), Image.NEAREST))
return images, masks, affine
class MaskScaleSmallest(object):
def __init__(self, smallest_range):
self.size = smallest_range
def __call__(self, images, masks):
assert len(images) > 0, "Non-empty array expected"
new_size = self.size[0] + int((self.size[1] - self.size[0]) * random.random())
w, h = images[0].size
aspect = w / h
if aspect > 1:
new_h = new_size
new_w = int(new_size * aspect)
else:
new_w = new_size
new_h = int(new_size / aspect)
new_size = (new_w, new_h)
for i, (image, mask) in enumerate(zip(images, masks)):
assert image.size == mask.size
assert image.size == (w, h)
images[i] = image.resize(new_size, Image.BILINEAR)
masks[i] = mask.resize(new_size, Image.NEAREST)
return images, masks
class MaskRandCrop:
def __init__(self, size, pad_if_needed=False):
self.size = size # (h, w)
self.pad_if_needed = pad_if_needed
def __pad(self, img, padding_mode='constant', fill=0):
# pad the width if needed
pad_width = self.size[1] - img.size[0]
pad_height = self.size[0] - img.size[1]
if self.pad_if_needed and (pad_width > 0 or pad_height > 0):
pad_l = max(0, pad_width // 2)
pad_r = max(0, pad_width - pad_l)
pad_t = max(0, pad_height // 2)
pad_b = max(0, pad_height - pad_t)
img = F.pad(img, (pad_l, pad_t, pad_r, pad_b), fill, padding_mode)
return img
def __call__(self, images, masks):
for i, (image, mask) in enumerate(zip(images, masks)):
images[i] = self.__pad(image)
masks[i] = self.__pad(mask, fill=1)
i, j, h, w = tf.RandomCrop.get_params(images[0], self.size)
for k, (image, mask) in enumerate(zip(images, masks)):
images[k] = F.crop(image, i, j, h, w)
masks[k] = F.crop(mask, i, j, h, w)
return images, masks
class MaskCenterCrop:
def __init__(self, size):
self.size = size # (h, w)
def __call__(self, images, masks):
for i, (image, mask) in enumerate(zip(images, masks)):
images[i] = F.center_crop(image, self.size)
masks[i] = F.center_crop(mask, self.size)
return images, masks
class MaskRandHFlip:
def __call__(self, images, masks):
if random.random() > 0.5:
for i, (image, mask) in enumerate(zip(images, masks)):
images[i] = F.hflip(image)
masks[i] = F.hflip(mask)
return images, masks
================================================
FILE: infer_vos.py
================================================
"""
Single-scale inference
Copyright (c) 2021 TU Darmstadt
Author: Nikita Araslanov <nikita.araslanov@tu-darmstadt.de>
License: Apache License 2.0
"""
import os
import sys
import numpy as np
import imageio
import time
import torch.multiprocessing as mp
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from opts import get_arguments
from core.config import cfg, cfg_from_file, cfg_from_list
from models import get_model
from utils.timer import Timer
from utils.sys_tools import check_dir
from utils.palette_davis import palette as davis_palette
from torch.utils.data import DataLoader
from datasets.dataloader_infer import DataSeg
from labelprop.common import LabelPropVOS_CRW
# deterministic inference
from torch.backends import cudnn
cudnn.enabled = True
cudnn.benchmark = False
cudnn.deterministic = True
VERBOSE = True
def mask2rgb(mask, palette):
mask_rgb = palette(mask)
mask_rgb = mask_rgb[:,:,:3]
return mask_rgb
def mask_overlay(mask, image, palette):
"""Creates an overlayed mask visualisation"""
mask_rgb = mask2rgb(mask, palette)
return 0.3 * image + 0.7 * mask_rgb
class ResultWriter:
def __init__(self, key, palette, out_path):
self.key = key
self.palette = palette
self.out_path = out_path
self.verbose = VERBOSE
def save(self, frames, masks_pred, masks_conf, masks_gt, flags, fn, seq_name):
subdir_vos = os.path.join(self.out_path, "{}_vos".format(self.key))
check_dir(subdir_vos, seq_name)
subdir_vis = os.path.join(self.out_path, "{}_vis".format(self.key))
check_dir(subdir_vis, seq_name)
for frame_id, mask in enumerate(masks_pred.split(1, 0)):
mask = mask[0].numpy().astype(np.uint8)
filepath = os.path.join(subdir_vos, seq_name, "{}.png".format(fn[frame_id][0]))
# saving only the flagged frames (e.g. every 5th)
if flags[frame_id] != 0:
imageio.imwrite(filepath, mask)
if self.verbose:
frame = frames[frame_id].numpy()
#mask_gt = masks_gt[frame_id].numpy().astype(np.uint8)
#masks = np.concatenate([mask, mask_gt], 1)
#frame = np.concatenate([frame, frame], 2)
frame = np.transpose(frame, [1,2,0])
overlay = mask_overlay(mask, frame, self.palette)
filepath = os.path.join(subdir_vis, seq_name, "{}.png".format(fn[frame_id][0]))
imageio.imwrite(filepath, (overlay * 255.).astype(np.uint8))
def convert_dict(state_dict):
new_dict = {}
for k,v in state_dict.items():
new_key = k.replace("module.", "")
new_dict[new_key] = v
return new_dict
def mask2tensor(mask, idx, num_classes=cfg.DATASET.NUM_CLASSES):
h,w = mask.shape
mask_t = torch.zeros(1,num_classes,h,w)
mask_t[0, idx] = mask
return mask_t
def configure_tracks(masks_gt, tracks, num_objects):
"""Selecting masks for initialisation
Args:
masks_gt: [T,H,W]
tracks: [T,2]
"""
init_masks = {}
# we always have first mask
# if there are no instances, it will be simply zero
H,W = masks_gt[0].shape[-2:]
init_masks[0] = torch.zeros(1, cfg.DATASET.NUM_CLASSES, H, W)
for oid in range(cfg.DATASET.NUM_CLASSES):
t = tracks[oid].item()
if not t in init_masks:
init_masks[t] = mask2tensor(masks_gt[oid], oid)
else:
init_masks[t] += mask2tensor(masks_gt[oid], oid)
return init_masks
def make_onehot(mask, HW):
# convert mask tensor with probabilities to a one-hot tensor
b,c,h,w = mask.shape
mask_up = F.interpolate(mask, HW, mode="bilinear", align_corners=True)
one_hot = torch.zeros_like(mask_up)
one_hot.scatter_(1, mask_up.argmax(1, keepdim=True), 1)
one_hot = F.interpolate(one_hot, (h,w), mode="bilinear", align_corners=True)
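# downscaling the hard one-hot map back to feature resolution yields soft class weights along object boundaries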
return one_hot
def scale_smallest(frame, a):
H,W = frame.shape[-2:]
s = a / min(H, W)
h, w = int(s * H), int(s * W)
return F.interpolate(frame, (h, w), mode="bilinear", align_corners=True)
def valid_mask(mask):
"""From a tensor [1,C,h,w]
create [1,C,1,1] 0/1 mask saying which IDs are present"""
B,C,h,w = mask.shape
valid = mask.flatten(2,3).sum(-1) > 0
valid = valid.type_as(mask).view(B,C,1,1)
return valid
def merge_mask_ids(masks, key0):
merged_mask = torch.zeros_like(masks[key0])
for tt, mask in masks.items():
merged_mask[:,1:] += mask[:,1:]
probs, ids = merged_mask.max(1, keepdim=True)
merged_mask.zero_()
merged_mask.scatter_(1, ids, probs)
merged_mask[:, :1] = 1 - probs
return merged_mask
def step_seg(cfg, net, labelprop, frames, mask_init):
# dense tracking: start from the 1st frame
# keep track of new objects
T = frames.shape[0]
frames = frames.cuda()
# scale smallest
if cfg.TEST.INPUT_SIZE > 0:
frames = scale_smallest(frames, cfg.TEST.INPUT_SIZE)
fetch = {"res3": lambda x: x[0], \
"res4": lambda x: x[1], \
"key": lambda x: x[2]}
for t in mask_init.keys():
mask_init[t] = mask_init[t].cuda()
H,W = mask_init[0].shape[-2:]
scale_as = lambda x, y: F.interpolate(x, y.shape[-2:], mode="bilinear", align_corners=True)
scale = lambda x, hw: F.interpolate(x, hw, mode="bilinear", align_corners=True)
# context (cxt) will maintain
# the reference frame
ref_embd = {} # context embeddings
ref_masks = {}
ref_valid = {}
# we will also keep a bunch
# of previous frames
prev_embd = None
prev_masks = None
all_masks = []
all_masks_conf = []
def add_result(mask):
mask_up = scale(mask, (H, W))
nxt_masks_conf, nxt_masks_id = mask_up.max(1)
all_masks.append(nxt_masks_id.cpu())
all_masks_conf.append(nxt_masks_conf.cpu())
# initialising
embd0 = net(frames[:1], embd_only=True, norm=True)
embd0 = fetch[cfg.TEST.KEY](embd0)
mask0 = scale_as(mask_init[0], embd0)
add_result(mask0)
ref_embd[0] = {0: embd0}
ref_masks[0] = {0: mask0}
ref_valid[0] = valid_mask(mask0) # [x,c,1,1]
# add this to the reference context
# if there are objects
ref_index = []
if mask_init[0].sum() > 0:
ref_index = [0]
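# the reference set starts empty if the first frame has no objects;
# frames where new GT masks appear are added inside the loop below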
print(">", end='')
for t in range(1, T):
print(".", end='')
sys.stdout.flush()
# next frame
frames_batch = frames[t:t+1]
# source forward pass
nxt_embds = net(frames_batch, embd_only=True, norm=True)
# fetching the feature
nxt_embd = fetch[cfg.TEST.KEY](nxt_embds)
ref_t = [0] if len(ref_index) == 0 else ref_index
# for each reference mask
# we will create own context, then
# propagate the labels and merge the result
nxt_masks = {}
for t0 in ref_t:
cxt_index = labelprop.context_index(t0, t)
cxt_embd = [ref_embd[t0][j] for j in cxt_index]
cxt_masks = [ref_masks[t0][j] for j in cxt_index]
nxt_masks[t0] = labelprop.predict(cxt_embd, cxt_masks, nxt_embd, cxt_index, t)
# merging all the masks
nxt_mask = sum([ref_valid[tt] * nxt_masks[tt] for tt in nxt_masks.keys()])
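# ref_valid[tt] is a [1,C,1,1] 0/1 gate, so each reference only contributes the object IDs it actually holds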
#nxt_mask = merge_mask_ids(nxt_masks, ref_t[0])
if t in mask_init: # not t >= 0
print("Adding GT mask t = ", t)
# adding the initial mask if just appeared
mask_init_dn = scale_as(mask_init[t], nxt_embd)
mask_init_dn_s = mask_init_dn.sum(1, keepdim=True)
nxt_mask = (1 - mask_init_dn_s) * nxt_mask + mask_init_dn_s * mask_init_dn
# adding to context
ref_embd[t] = {}
ref_masks[t] = {}
ref_valid[t] = valid_mask(mask_init[t])
ref_index.append(t)
add_result(nxt_mask)
ref_t = [0] if len(ref_index) == 0 else ref_index
#
# updating the context
# two parts: for every initial index, keep first N and last M frames
for t0 in ref_t:
ref_embd[t0][t] = nxt_embd.clone()
ref_masks[t0][t] = nxt_mask.clone()
index_short = labelprop.context_long(t0, t)
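# note: context_long() returns the long-term indices (just [t0]); these are exempt from the eviction below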
tsteps = list(ref_embd[t0].keys())
for tt in tsteps:
if t - tt > cfg.TEST.CXT_SIZE and not tt in index_short:
del ref_embd[t0][tt]
del ref_masks[t0][tt]
print('<')
masks_pred = torch.cat(all_masks, 0)
masks_pred_conf = torch.cat(all_masks_conf, 0)
return masks_pred, masks_pred_conf
if __name__ == '__main__':
# loading the model
args = get_arguments(sys.argv[1:])
# reading the config
cfg_from_file(args.cfg_file)
if args.set_cfgs is not None:
cfg_from_list(args.set_cfgs)
# initialising the dirs
check_dir(args.mask_output_dir, "{}_vis".format(cfg.TEST.KEY))
check_dir(args.mask_output_dir, "{}_vos".format(cfg.TEST.KEY))
# Loading the model
model = get_model(cfg, remove_layers=cfg.MODEL.REMOVE_LAYERS)
labelprop = LabelPropVOS_CRW(cfg)
if not os.path.isfile(args.resume):
print("[W]: ", "Snapshot not found: {}".format(args.resume))
print("[W]: Using a random model")
else:
state_dict = convert_dict(torch.load(args.resume)["model"])
try:
model.load_state_dict(state_dict)
except Exception as e:
print("Error while loading the snapshot:\n", str(e))
print("Resuming...")
model.load_state_dict(state_dict, strict=False)
for p in model.parameters():
p.requires_grad = False
# setting the evaluation mode
model.eval()
model = model.cuda()
dataset = DataSeg(cfg, args.infer_list)
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, \
drop_last=False) #, num_workers=args.workers)
palette = dataloader.dataset.get_palette()
timer = Timer()
N = len(dataloader)
pool = mp.Pool(processes=args.workers)
writer = ResultWriter(cfg.TEST.KEY, davis_palette, args.mask_output_dir)
for iter, batch in enumerate(dataloader):
frames_orig, frames, masks_gt, tracks, num_ids, fns, flags, seq_name = batch
print("Sequence {:02d} | {}".format(iter, seq_name[0]))
masks_gt = masks_gt.flatten(0,1)
frames_orig = frames_orig.flatten(0,1)
frames = frames.flatten(0,1)
tracks = tracks.flatten(0,1)
flags = flags.flatten(0,1)
init_masks = configure_tracks(masks_gt, tracks, num_ids[0])
assert 0 in init_masks, "initial frame has no instances"
with torch.no_grad():
masks_pred, masks_conf = step_seg(cfg, model, labelprop, frames, init_masks)
frames_orig = dataset.denorm(frames_orig)
pool.apply_async(writer.save, args=(frames_orig, masks_pred.cpu(), masks_conf.cpu(), masks_gt.cpu(), flags, fns, seq_name[0]))
timer.stage("Inference completed")
pool.close()
pool.join()
================================================
FILE: labelprop/common.py
================================================
"""
Based on the inference routines from Jabri et al., (2020)
Credit: https://github.com/ajabri/videowalk.git
License: MIT
"""
import sys
import torch
from labelprop.crw import MaskedAttention
from labelprop.crw import mem_efficient_batched_affinity as batched_affinity
class LabelPropVOS(object):
def context_long(self):
"""Returns indices of the timesteps
for long-term memory
"""
raise NotImplementedError()
def context_short(self, t):
"""
Args:
t: current timestep
Returns:
list: indices of timesteps
for the context
"""
raise NotImplementedError()
def predict(self, feats, masks, curr_feat):
"""
Args:
feats [C,K,h,w]: context features
masks [C,M,h,w]: context masks
curr_feat [1,K,h,w]: current frame features
Returns:
mask [1,M,h,w]: current frame mask
"""
raise NotImplementedError()
class LabelPropVOS_CRW(LabelPropVOS):
def __init__(self, cfg):
self.cxt_size = cfg.TEST.CXT_SIZE
self.radius = cfg.TEST.RADIUS
self.temperature = cfg.TEST.TEMP
self.topk = cfg.TEST.KNN
self.mask = None
self.mask_hw = None
def context_long(self, t0, t):
return [t0]
def context_short(self, t0, t):
to_t = t
from_t = to_t - self.cxt_size
timesteps = [max(tt, t0) for tt in range(from_t, to_t)]
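# clamping to t0 means the reference frame fills the context slots until enough frames have been observed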
return timesteps
def context_index(self, t0, t):
index_short = self.context_short(t0, t)
index_long = self.context_long(t0, t)
cxt_index = index_long + index_short
return cxt_index
def predict(self, feats, masks, curr_feat, ref_index=None, t=None):
"""
Args:
feats: list of C [1,K,h,w] context features
masks: list of C [1,M,h,w] context masks
curr_feat: [1,K,h,w] current frame features
ref_index: C indices of context frames
t: current frame time step
Returns:
mask [1,M,h,w]: current frame mask
"""
dev = curr_feat.device
h, w = curr_feat.shape[-2:]
# [BC+N,M,h,w] -> [BC+N,h,w,M]
ctx_lbls = torch.cat(masks, 0).permute([0,2,3,1])
# keys [1,K,C,h,w]: context features
keys = torch.stack(feats, 2)[:, :, None]
# query: [1,K,1,h,w]: reference feature
query = curr_feat[:, :, None]
if self.mask is None or self.mask_hw != (h, w):
# Make spatial radius mask TODO use torch.sparse
restrict = MaskedAttention(self.radius, flat=False)
D = restrict.mask(h, w)[None]
D = D.flatten(-4, -3).flatten(-2)
D[D==0] = -1e10; D[D==1] = 0
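# turn the binary locality mask into an additive attention bias: 0 inside the radius, -1e10 (effectively -inf) outside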
self.mask = D.to(dev)
self.mask_hw = (h, w)
# Flatten source frame features to make context feature set
keys, query = keys.flatten(-2), query.flatten(-2)
long_mem = [0]
Ws, Is = batched_affinity(query, keys, self.mask, \
self.temperature, self.topk, long_mem)
# Soft labels of source nodes
ctx_lbls = ctx_lbls.flatten(0, 2).transpose(0, 1)
# Weighted sum of top-k neighbours (Is is index, Ws is weight)
pred = (ctx_lbls[:, Is[0].to(dev)] * Ws[0][None].to(dev)).sum(1)
pred = pred.view(-1, h, w)
pred = pred.permute(1,2,0)
# Adding Predictions
pred = pred.permute([2,0,1])[None, ...]
return pred
================================================
FILE: labelprop/crw.py
================================================
"""
Inference routines from Jabri et al., (2020)
Credit: https://github.com/ajabri/videowalk.git
License: MIT
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
class CRW(object):
"""Propagation algorithm"""
def __init__(self, cfg):
self.n_context = cfg.CXT_SIZE
self.radius = cfg.RADIUS
self.temperature = cfg.TEMP
self.topk = cfg.KNN
print("Inference Opts:")
print("Context size: {}".format(self.n_context))
print(" Radius: {}".format(self.radius))
print(" Temp: {}".format(self.temperature))
print(" TopK: {}".format(self.topk))
# always keeping the first frame
# TODO: move to cfg
self.long_mem = [0]
# for bwd-compatibility
self.norm_mask = False
self.mask = None
self.mask_hw = None
def _prep_context(self, feats, lbls, hw):
"""Adjust for context
Args:
lbls: [N,M,H,W]
"""
lbls = F.interpolate(lbls, hw, mode="bilinear", align_corners=True)
ref = lbls[:1].expand(self.n_context,-1,-1,-1)
lbls = torch.cat([ref, lbls], 0)
fref = feats[:1].expand(self.n_context,-1,-1,-1)
fref = torch.cat([fref, feats], 0)
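# the first (annotated) frame is replicated n_context times so even the earliest target frames see a full-length context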
return fref, lbls
def forward(self, feats, lbls):
"""Propagate features
Args:
feats: [N,K,h,w]
lbls: [N,M,H,W]
"""
N,K,h,w = feats.shape
M,H,W = lbls.shape[-3:]
feats, lbls = self._prep_context(feats, lbls, (h,w))
# [N,K,h,w] -> [1,K,N,h,w]
feats = feats.permute([1,0,2,3])
# singleton for compatibility
feats = feats[None,...]
# [BC+N,M,h,w] -> [BC+N,h,w,M]
lbls = lbls.permute([0,2,3,1])
n_context = self.n_context
torch.cuda.empty_cache()
# Prepare source (keys) and target (query) frame features
key_indices = context_index_bank(n_context, self.long_mem, N)
key_indices = torch.cat(key_indices, dim=-1)
keys = feats[:, :, key_indices]
query = feats[:, :, n_context:]
# Make spatial radius mask TODO use torch.sparse
if self.mask is None or self.mask_hw != (h, w):
restrict = MaskedAttention(self.radius, flat=False)
D = restrict.mask(h, w)[None]
D = D.flatten(-4, -3).flatten(-2)
D[D==0] = -1e10; D[D==1] = 0
self.mask = D.cuda()
self.mask_hw = (h, w)
# Flatten source frame features to make context feature set
keys, query = keys.flatten(-2), query.flatten(-2)
Ws, Is = mem_efficient_batched_affinity(query, keys, self.mask, \
self.temperature, self.topk, self.long_mem)
##################################################################
# Propagate Labels and Save Predictions
###################################################################
masks_idx = torch.LongTensor(N,H,W)
masks_prob = torch.FloatTensor(N,H,W)
for t in range(key_indices.shape[0]):
# Soft labels of source nodes
ctx_lbls = lbls[key_indices[t]].cuda()
ctx_lbls = ctx_lbls.flatten(0, 2).transpose(0, 1)
# Weighted sum of top-k neighbours (Is is index, Ws is weight)
pred = (ctx_lbls[:, Is[t]] * Ws[t][None].cuda()).sum(1)
pred = pred.view(-1, h, w)
pred = pred.permute(1,2,0)
if t > 0:
lbls[t + n_context] = pred
else:
pred = lbls[0]
lbls[t + n_context] = pred
if self.norm_mask:
pred[:, :, :] -= pred.min(-1)[0][:, :, None]
pred[:, :, :] /= pred.max(-1)[0][:, :, None]
# Adding Predictions
pred_ = pred.permute([2,0,1])[None, ...]
pred_up = F.interpolate(pred_, (H,W), mode="bilinear", align_corners=True)
pred_up = pred_up[0].cpu()
masks_idx[t] = pred_up.argmax(0)
masks_prob[t] = pred_up[1]
out = {}
out["masks_pred_idx"] = masks_idx
out["masks_pred_conf"] = masks_prob
return out
def context_index_bank(n_context, long_mem, N):
'''
Construct bank of source frames indices, for each target frame
'''
ll = [] # "long term" context (i.e. first frame)
for t in long_mem:
assert 0 <= t < N, 'context frame out of bounds'
idx = torch.zeros(N, 1).long()
if t > 0:
idx += t + (n_context+1)
idx[:n_context+t+1] = 0
ll.append(idx)
# "short" context
ss = [(torch.arange(n_context)[None].repeat(N, 1) + torch.arange(N)[:, None])[:, :]]
return ll + ss
def batched_affinity(query, keys, mask, temperature, topk, long_mem, device):
'''
Computes the affinity in a single pass (no mini-batching);
see mem_efficient_batched_affinity for the memory-efficient variant
'''
A = torch.einsum('ijklm,ijkn->iklmn', keys, query)
# Mask
A[0, :, len(long_mem):] += mask.to(device)
_, N, T, h1w1, hw = A.shape
A = A.view(N, T*h1w1, hw)
A /= temperature
weights, ids = torch.topk(A, topk, dim=-2)
weights = F.softmax(weights, dim=-2)
Ws = [w for w in weights]
Is = [ii for ii in ids]
return Ws, Is
def mem_efficient_batched_affinity(query, keys, mask, temperature, topk, long_mem):
'''
Mini-batched computation of affinity, for memory efficiency
'''
bsize, pbsize = 2, 100
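# bsize: target frames per outer chunk; pbsize: query pixels per inner chunk (keeps peak GPU memory bounded)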
Ws, Is = [], []
for b in range(0, keys.shape[2], bsize):
_k, _q = keys[:, :, b:b+bsize].cuda(), query[:, :, b:b+bsize].cuda()
w_s, i_s = [], []
for pb in range(0, _k.shape[-1], pbsize):
A = torch.einsum('ijklm,ijkn->iklmn', _k, _q[..., pb:pb+pbsize])
A[0, :, len(long_mem):] += mask[..., pb:pb+pbsize]
_, N, T, h1w1, hw = A.shape
A = A.view(N, T*h1w1, hw)
A /= temperature
weights, ids = torch.topk(A, topk, dim=-2)
weights = F.softmax(weights, dim=-2)
w_s.append(weights)
i_s.append(ids)
weights = torch.cat(w_s, dim=-1)
ids = torch.cat(i_s, dim=-1)
Ws += [w for w in weights]
Is += [ii for ii in ids]
return Ws, Is
class MaskedAttention(nn.Module):
'''
A module that implements masked attention based on spatial locality
TODO implement in a more efficient way (torch sparse or correlation filter)
'''
def __init__(self, radius, flat=True):
super(MaskedAttention, self).__init__()
self.radius = radius
self.flat = flat
self.masks = {}
self.index = {}
def mask(self, H, W):
if not ('%s-%s' %(H,W) in self.masks):
self.make(H, W)
return self.masks['%s-%s' %(H,W)]
def index(self, H, W):
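# note: the `self.index` dict assigned in __init__ shadows this method, so it is never reachable via self.index(H, W)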
if not ('%s-%s' %(H,W) in self.index):
self.make_index(H, W)
return self.index['%s-%s' %(H,W)]
def make(self, H, W):
if self.flat:
H = int(H**0.5)
W = int(W**0.5)
gx, gy = torch.meshgrid(torch.arange(0, H), torch.arange(0, W))
D = ( (gx[None, None, :, :] - gx[:, :, None, None])**2 + (gy[None, None, :, :] - gy[:, :, None, None])**2 ).float() ** 0.5
D = (D < self.radius)[None].float()
if self.flat:
D = self.flatten(D)
self.masks['%s-%s' %(H,W)] = D
return D
def flatten(self, D):
return torch.flatten(torch.flatten(D, 1, 2), -2, -1)
def make_index(self, H, W, pad=False):
mask = self.mask(H, W).view(1, -1).byte()
idx = torch.arange(0, mask.numel())[mask[0]][None]
self.index['%s-%s' %(H,W)] = idx
return idx
def forward(self, x):
H, W = x.shape[-2:]
sid = '%s-%s' % (H,W)
if sid not in self.masks:
self.masks[sid] = self.make(H, W).to(x.device)
mask = self.masks[sid]
return x * mask[0]
================================================
FILE: launch/infer_vos.sh
================================================
#!/bin/bash
#
# Arguments
#
# suffix for the output directory (see below)
VER=v01
# defines the path to the snapshot (see below)
EXP=ours
RUN_ID=final
# SNAPSHOT name. Note: '.pth' will be appended
# See README.md to download these snapshots
SNAPSHOT=ytvos_e060_res4 # YouTube-VOS
#SNAPSHOT=trackingnet_e088_res4 # TrackingNet
#SNAPSHOT=oxuva_e430_res4 # OxUvA
#SNAPSHOT=kinetics_e026_res4 # Kinetics
# codename of the final output layer [res3|res4|key]
KEY=res4
#
# Changing the following is not necessary
#
DS=$1
case $DS in
ytvos)
echo "Test dataset: YouTube-VOS 2018 (val)"
FILELIST=filelists/val_ytvos2018_test
;;
davis)
echo "Test dataset: DAVIS-2017 (val)"
FILELIST=filelists/val_davis2017_test
;;
*)
echo "Dataset '$DS' not recognised. Should be one of [ytvos|davis]."
exit 1
;;
esac
# The config file is irrelevant
# since the inference parameters are
# always the same
CONFIG=configs/ytvos.yaml
OUTPUT_DIR=./results
EXTRA="$EXTRA --seed 0 --set TEST.KEY $KEY"
SAVE_ID=${RUN_ID}_${SNAPSHOT}_${VER}
SNAPSHOT_PATH=snapshots/${EXP}/${RUN_ID}/${SNAPSHOT}.pth
if [ ! -f $SNAPSHOT_PATH ]; then
echo "Snapshot $SNAPSHOT_PATH NOT found."
exit 1;
fi
#
# Code goes here
#
LISTNAME=`basename $FILELIST .txt`
SAVE_DIR=$OUTPUT_DIR/$EXP/$SAVE_ID/$LISTNAME
LOG_FILE=$OUTPUT_DIR/$EXP/$SAVE_ID/${LISTNAME}_${KEY}.log
NUM_THREADS=12
export OMP_NUM_THREADS=$NUM_THREADS
export MKL_NUM_THREADS=$NUM_THREADS
CMD="python infer_vos.py --cfg $CONFIG \
--exp $EXP \
--run $RUN_ID \
--resume $SNAPSHOT_PATH \
--infer-list $FILELIST \
--mask-output-dir $SAVE_DIR \
$EXTRA"
if [ ! -d $SAVE_DIR ]; then
echo "Creating directory: $SAVE_DIR"
mkdir -p $SAVE_DIR
else
echo "Saving to: $SAVE_DIR"
fi
git rev-parse HEAD > ${SAVE_DIR}.head
git diff > ${SAVE_DIR}.diff
echo $CMD > ${SAVE_DIR}.cmd
echo $CMD
nohup $CMD > $LOG_FILE 2>&1 &
sleep 1
tail -f $LOG_FILE
================================================
FILE: launch/train.sh
================================================
#!/bin/bash
# Set the following variables
# The tensorboard logs will be created in logs/<EXP>/<EXP_ID>
# The snapshots will be saved in snapshots/<EXP>/<EXP_ID>
EXP=v01_vos
EXP_ID=v01_00_base
#
# No changes are necessary below this point
#
DS=$1
SEED=32
case $DS in
oxuva)
echo "Train dataset: OxUvA"
TASK="OxUvA_all"
CFG=configs/oxuva.yaml
EXP_ID="OX_${EXP_ID}"
;;
ytvos)
echo "Train dataset: YouTube-VOS"
TASK="YTVOS"
CFG=configs/ytvos.yaml
EXP_ID="YT_${EXP_ID}"
;;
track)
echo "Train dataset: TrackingNet"
TASK="TrackingNet"
CFG=configs/tracknet.yaml
EXP_ID="TN_${EXP_ID}"
;;
kinetics)
echo "Train dataset: Kinetics-400"
CFG=configs/kinetics.yaml
EXP_ID="KT_${EXP_ID}"
;;
*)
echo "Dataset '$DS' not recognised. Should be one of [oxuva|ytvos|track|kinetics]."
exit 1
;;
esac
CURR_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
source $CURR_DIR/utils.bash
CMD="python train.py --cfg $CFG --exp $EXP --run $EXP_ID --seed $SEED"
LOG_DIR=logs/${EXP}
LOG_FILE=$LOG_DIR/${EXP_ID}.log
echo "LOG: $LOG_FILE"
check_rundir $LOG_DIR $EXP_ID
NUM_THREADS=12
export OMP_NUM_THREADS=$NUM_THREADS
export MKL_NUM_THREADS=$NUM_THREADS
echo $CMD
CMD_FILE=$LOG_DIR/${EXP_ID}.cmd
echo $CMD > $CMD_FILE
git rev-parse HEAD > $LOG_DIR/${EXP_ID}.head
git diff > $LOG_DIR/${EXP_ID}.diff
nohup $CMD > $LOG_FILE 2>&1 &
sleep 1
tail -f $LOG_FILE
================================================
FILE: launch/utils.bash
================================================
#!/bin/bash
check_rundir()
{
LOG_DIR="$1"
EXP_ID="$2"
if [ ! -d "$LOG_DIR" ]; then
echo "Creating directory $LOG_DIR"
mkdir -p $LOG_DIR
else
LOGD=$LOG_DIR/$EXP_ID
if [ -d "$LOGD" ]; then
echo "Directory $LOGD already exists."
read -p "Do you want to remove the log files?: " -n 1 -r
echo
if [[ ! $REPLY =~ ^[Yy]$ ]]
then
exit;
else
rm -rf $LOGD
fi
fi
fi
}
================================================
FILE: models/__init__.py
================================================
"""
Copyright (c) 2021 TU Darmstadt
Author: Nikita Araslanov <nikita.araslanov@tu-darmstadt.de>
License: Apache License 2.0
"""
from .resnet18 import resnet18
from .net import Net
from .framework import Framework
def get_model(cfg, *args, **kwargs):
backbones = {
'resnet18': resnet18
}
def create_net():
backbone = backbones[cfg.MODEL.ARCH.lower()](*args, **kwargs)
return Net(cfg, backbone)
net = create_net()
return Framework(cfg, net)
================================================
FILE: models/base.py
================================================
"""
Base class for network models
Copyright (c) 2021 TU Darmstadt
Author: Nikita Araslanov <nikita.araslanov@tu-darmstadt.de>
License: Apache License 2.0
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class BaseNet(nn.Module):
_trainable = (nn.Linear, nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d, nn.GroupNorm, nn.InstanceNorm2d, nn.SyncBatchNorm)
_batchnorm = (nn.BatchNorm2d, nn.SyncBatchNorm, nn.GroupNorm)
def __init__(self):
super().__init__()
# we may want a different learning rate
# for new layers
self.from_scratch_layers = []
# we may want to freeze some layers
self.not_training = []
# we may want to freeze BN (means/stds only)
self.bn_freeze = []
def lr_mult(self):
"""Learning rate multiplier for weights.
Returns: [old, new]"""
return 1., 1.
def lr_mult_bias(self):
"""Learning rate multiplier for bias.
Returns: [old, new]"""
return 2., 2.
def _is_learnable(self, layer):
return isinstance(layer, BaseNet._trainable)
def _from_scratch(self, net, ignore=[]):
for layer in net.modules():
if self._is_learnable(layer):
self.from_scratch_layers.append(layer)
def _freeze_bn(self, net, ignore=[]):
"""Add layers to use in .eval() mode only"""
for layer in net.modules():
if isinstance(layer, BaseNet._batchnorm) and \
not layer in ignore:
assert hasattr(layer, "eval") and callable(layer.eval)
print("Freezing ", layer)
self.bn_freeze.append(layer)
print("Frozen BN: ", len(self.bn_freeze))
def _fix_bn(self, layer):
if isinstance(layer, nn.BatchNorm2d):
self.not_training.append(layer)
elif isinstance(layer, nn.Module):
for c in layer.children():
self._fix_bn(c)
def __set_grad_mode(self, layer, mode, only_type=None):
if hasattr(layer, "weight"):
if only_type is None or isinstance(layer, only_type):
layer.weight.requires_grad = mode
if hasattr(layer, "bias") and not layer.bias is None:
if only_type is None or isinstance(layer, only_type):
layer.bias.requires_grad = mode
if isinstance(layer, nn.Module):
for c in layer.children():
self.__set_grad_mode(c, mode)
def train(self, mode=True):
super().train(mode)
# some layers have to be frozen
for layer in self.not_training:
self.__set_grad_mode(layer, False)
for layer in self.bn_freeze:
assert hasattr(layer, "eval") and callable(layer.eval)
layer.eval()
def parameter_groups(self, base_lr, wd):
w_old, w_new = self.lr_mult()
b_old, b_new = self.lr_mult_bias()
groups = ({"params": [], "weight_decay": wd, "lr": w_old*base_lr}, # weight learning
{"params": [], "weight_decay": 0.0, "lr": b_old*base_lr}, # bias learning
{"params": [], "weight_decay": wd, "lr": w_new*base_lr}, # weight finetuning
{"params": [], "weight_decay": 0.0, "lr": b_new*base_lr}) # bias finetuning
for m in self.modules():
if not self._is_learnable(m):
if hasattr(m, "weight") or hasattr(m, "bias"):
print("Skipping layer with parameters: ", m)
continue
if not m.weight is None and m.weight.requires_grad:
if m in self.from_scratch_layers:
groups[2]["params"].append(m.weight)
else:
groups[0]["params"].append(m.weight)
elif not m.weight is None:
print("Skipping W: ", m, m.weight.size())
if m.bias is not None and m.bias.requires_grad:
if m in self.from_scratch_layers:
groups[3]["params"].append(m.bias)
else:
groups[1]["params"].append(m.bias)
elif m.bias is not None:
print("Skipping b: ", m, m.bias.size())
return groups
@staticmethod
def _resize_as(x, y):
return F.interpolate(x, y.size()[-2:], mode="bilinear", align_corners=True)
================================================
FILE: models/framework.py
================================================
"""
Copyright (c) 2021 TU Darmstadt
Author: Nikita Araslanov <nikita.araslanov@tu-darmstadt.de>
License: Apache License 2.0
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.base import BaseNet
class Framework(BaseNet):
def __init__(self, cfg, net):
super(Framework, self).__init__()
self.cfg = cfg
self.fast_net = net
self.eye = None
def parameter_groups(self, base_lr, wd):
return self.fast_net.parameter_groups(base_lr, wd)
def _align(self, x, t):
tf = F.affine_grid(t, size=x.size(), align_corners=False)
return F.grid_sample(x, tf, align_corners=False, mode="nearest")
def _key_val(self, ctr, q):
"""
Args:
ctr: [N,K]
q: [BHW,K]
Returns:
val: [BHW,N]
"""
# [BHW,K] x [N,K].t -> [BHWxN]
vals = torch.mm(q, ctr.t()) # [BHW,N]
# normalising attention
return vals / self.cfg.TEST.TEMP
def _sample_index(self, x, T, N):
"""Sample indices of the anchors
Args:
x: [BT,K,H,W]
Returns:
index: [B,N*N,K]
"""
BT,K,H,W = x.shape
B = x.view(-1,T,K,H*W).shape[0]
# sample indices from a uniform grid
xs, ys = W // N, H // N
x_sample = torch.arange(0, W, xs).view(1, 1, N)
y_sample = torch.arange(0, H, ys).view(1, N, 1)
# Random offsets
# [B x 1 x N]
x_sample = x_sample + torch.randint(0, xs, (B, 1, 1))
# [B x N x 1]
y_sample = y_sample + torch.randint(0, ys, (B, 1, 1))
# batch index
# [B x N x N]
hw_index = torch.LongTensor(x_sample + y_sample * W)
return hw_index
def _sample_from(self, x, index, T, N):
"""Gather the features based on the index
Args:
x: [BT,K,H,W]
index: [B,N,N] defines the indices of NxN grid for a single
frame in each of B video clips
Returns:
anchors: [BNN,K] sampled features given by index from x
"""
BT,K,H,W = x.shape
x = x.view(-1,T,K,H*W)
B = x.shape[0]
# > [B,T,K,HW] > [B,T,HW,K] > [B,THW,K]
x = x.permute([0,1,3,2]).reshape(B,-1,K)
# every video clip will have the same samples
# on the grid
# [B x N x N] -> [B x N*N x 1] -> [B x N*N x K]
index = index.view(B,-1,1).expand(-1,-1,K)
# selecting from the uniform grid
y = x.gather(1, index.to(x.device))
# [BNN,K]
return y.flatten(0,1)
def _mark_from(self, x, index, T, N, fill_value=0):
"""This is analogous to _sample_from except that
here we simply "mark" the sampled positions in the tensor
Used for visualisation only.
Since it is a binary mask, K == 1
Args:
x: [BT,1,H,W] binary mask
index: [B,N,N] defines the indices of NxN grid for a single
frame in each of B video clips
Returns:
y: [BT,1,H,W] marked sample positions
"""
BT,K,H,W = x.shape
assert K == 1, "Expected binary mask"
x = x.view(-1,T,K,H*W)
B = x.shape[0]
# > [B,T,K,HW] > [B,T,HW,K] > [B,THW,K]
x = x.permute([0,1,3,2]).reshape(B,-1,K)
# every video clip will have the same samples
# on the grid
# [B x N x N] -> [B x N*N x 1] -> [B x N*N x K]
index = index.view(B,-1,1).expand(-1,-1,K)
# selecting from the uniform grid
# [B x T*H*W x K]
y = x.scatter(1, index.to(x.device), fill_value)
# [B x T*H*W x K] -> [BT x K x H x W]
return y.view(-1,H*W,K).permute([0,2,1]).view(-1,K,H,W)
def _cluster_grid(self, k1, k2, aff1, aff2, T, index=None):
""" Selecting clusters within a sequence
Args:
k1: [BT,K,H,W]
k2: [BT,K,H,W]
"""
BT,K,H,W = k1.shape
assert BT % T == 0, "Batch not divisible by sequence length"
B = BT // T
# N = [G x G]
N = self.cfg.MODEL.GRID_SIZE ** 2
# [BT,K,H,W] -> [BTHW,K]
flatten = lambda x: x.flatten(2,3).permute([0,2,1]).flatten(0,1)
# [BTHW,BN] -> [BT,BN,H,W]
def unflatten(x, aff=None):
x = x.view(BT,H*W,-1).permute([0,2,1]).view(BT,-1,H,W)
if aff is None:
return x
return self._align(x, aff)
index = self._sample_index(k1, T, N = self.cfg.MODEL.GRID_SIZE)
query1 = self._sample_from(k1, index, T, N = self.cfg.MODEL.GRID_SIZE)
"""Computing the distances and pseudo labels"""
# [BTHW,K]
k1_ = flatten(k1)
k2_ = flatten(k2)
# [BTHW,BN] -> [BTHW,BN] -> [BT,BN,H,W]
vals_soft = unflatten(self._key_val(query1, k1_), aff1)
vals_pseudo = unflatten(self._key_val(query1, k2_), aff2)
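# vals_soft scores the anchors against view-1 features (prediction branch);
# vals_pseudo scores them against view-2 features, which provide the pseudo-labels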
# [BT,BN,H,W]
probs_pseudo = self._pseudo_mask(vals_pseudo, T)
probs_pseudo2 = self._pseudo_mask(vals_soft, T)
pseudo = probs_pseudo.argmax(1)
pseudo2 = probs_pseudo2.argmax(1)
# mask
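# grid_mask is returned as a closure and only evaluated when visualisation is requested (dbg=True in forward)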
def grid_mask():
grid_mask = torch.ones(BT,1,H,W).to(pseudo.device)
return self._mark_from(grid_mask, index, T, N = self.cfg.MODEL.GRID_SIZE)
return vals_soft, pseudo, index, [vals_pseudo, pseudo2, grid_mask]
# sampling affinity
def _aff_sample(self, k1, k2, T):
BT,K,h,w = k1.size()
B = BT // T
hw = h*w
def gen(query):
grid_mask = torch.ones(B,1,hw).to(k1.device)
# generating random indices
indices = torch.randint(0, hw, (B,1,1)).to(k1.device)
grid_mask.scatter_(2, indices, 0)
# [B,K,H,W] -> [B,K,1]
query_ = query[::T].view(B,K,-1).gather(2, indices.expand(-1,K,-1))
def aff(keys):
k = keys.view(B,T,K,-1)
# [B,T,K,HW] x [B,1,K,HW] -> [B,T,HW]
aff = (k * query_[:,None,:,:]).sum(2)
return (aff + 1) / 2
aff1 = aff(k1)
aff2 = aff(k2)
return grid_mask.view(B,h,w), aff1.view(BT,h,w), aff2.view(BT,h,w)
grid_mask1, aff1_1, aff1_2 = gen(k1)
grid_mask2, aff2_1, aff2_2 = gen(k2)
return grid_mask1, aff1_1, aff1_2, \
grid_mask2, aff2_1, aff2_2
def _pseudo_mask(self, logits, T):
BT,K,h,w = logits.shape
assert BT % T == 0, "Batch not divisible by sequence length"
B = BT // T
# N = [G x G]
N = self.cfg.MODEL.GRID_SIZE ** 2
# generating a pseudo label
# first we need to mask out the affinities across the batch
if self.eye is None or self.eye.shape[0] != B*T \
or self.eye.shape[1] != B*N:
eye = torch.eye(B)[:,:,None].expand(-1,-1,N).reshape(B,-1)
eye = eye.unsqueeze(1).expand(-1,T,-1).reshape(B*T, B*N, 1, 1)
self.eye = eye.to(logits.device)
probs = F.softmax(logits, 1)
return probs * self.eye
def _ref_loss(self, x, y, N = 4):
B,_,h,w = x.shape
index = self._sample_index(x, T=1, N=N)
x1 = self._sample_from(x, index, T=1, N=N)
y1 = self._sample_from(y, index, T=1, N=N)
logits = torch.mm(x1, y1.t()) / self.cfg.TEST.TEMP
labels = torch.arange(logits.size(1)).to(logits.device)
return F.cross_entropy(logits, labels)
def _ce_loss(self, x, pseudo_map, T, eps=1e-5):
error_map = F.cross_entropy(x, pseudo_map, reduction="none", ignore_index=-1)
BT,h,w = error_map.shape
errors = error_map.view(-1,T,h,w)
error_ref, error_t = errors[:,0], errors[:,1:]
return error_ref.mean(), error_t.mean(), error_map
def _forward_reg(self, frames2, norm):
losses = {}
if not self.cfg.TRAIN.STOP_GRAD:
k2, res3, res4 = self.fast_net(frames2, norm)
return k2, res3, res4, losses
training = self.fast_net.training
if self.cfg.TRAIN.BLOCK_BN:
self.fast_net.eval()
with torch.no_grad():
k2, res3, res4 = self.fast_net(frames2, norm)
if self.cfg.TRAIN.BLOCK_BN:
self.fast_net.train(training)
return k2, res3, res4, losses
def fetch_first(self, x1, x2, T):
assert x1.shape[1:] == x2.shape[1:]
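# the reference frame of view 2 was appended to view-1's batch (see Trainer.step in train.py);
# split it off here and prepend it back to view 2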
c,h,w = x1.shape[1:]
x1 = x1.view(-1,T+1,c,h,w)
x2 = x2.view(-1,T-1,c,h,w)
x2 = torch.cat((x1[:,-1:], x2), 1)
x1 = x1[:,:-1]
return x1.flatten(0,1), x2.flatten(0,1)
def forward(self, frames, frames2=None, mask=None, T=None, affine=None, affine2=None, embd_only=False, norm=True, dbg=False):
"""Extract temporal correspondences
Args:
frames: [B,T,C,H,W]
Returns:
losses: a dictionary with the embedding loss
net_outs: feature embeddings
"""
# embedding for self-supervised learning
key1, res3, res4 = self.fast_net(frames, norm)
outs, losses = {}, {}
if embd_only: # only embedding
return res3, res4, key1
else:
key2, res3_2, res4_2, losses = self._forward_reg(frames2, norm)
# fetching the first frame from the second view
key1, key2 = self.fetch_first(key1, key2, T)
vals, pseudo, index, dbg_info = self._cluster_grid(key1, key2, affine, affine2, T)
vals_pseudo, pseudo2, grid_mask = dbg_info
key1_aligned = self._align(key1, affine)
key2_aligned = self._align(key2, affine2)
n_ref = self.cfg.MODEL.GRID_SIZE_REF
losses["cross_key"] = self._ref_loss(key1_aligned[::T], key2_aligned[::T], N = n_ref)
# losses
_, losses["temp"], outs["error_map"] = self._ce_loss(vals, pseudo, T)
# computing the main loss
losses["main"] = self.cfg.MODEL.CE_REF * losses["cross_key"] + losses["temp"]
if dbg:
vals = F.softmax(vals, 1)
vals_pseudo = F.softmax(vals_pseudo, 1)
frames, frames2 = self.fetch_first(frames, frames2, T)
outs["frames_orig"] = frames
outs["frames"] = self._align(frames, affine)
outs["frames2"] = self._align(frames2, affine2)
outs["map_soft"] = vals
outs["map"] = pseudo
outs["map_target_soft"] = vals_pseudo
outs["map_target"] = pseudo2
outs["grid_mask"] = grid_mask()
outs["aff_mask1"], outs["aff11"], outs["aff12"], \
outs["aff_mask2"], outs["aff21"], outs["aff22"] = self._aff_sample(key1, key2, T)
return losses, outs
================================================
FILE: models/net.py
================================================
"""
Copyright (c) 2021 TU Darmstadt
Author: Nikita Araslanov <nikita.araslanov@tu-darmstadt.de>
License: Apache License 2.0
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.base import BaseNet
class MLP(nn.Sequential):
def __init__(self, n_in, n_out):
super().__init__()
self.add_module("conv1", nn.Conv2d(n_in, n_in, 1, 1))
self.add_module("bn1", nn.BatchNorm2d(n_in))
self.add_module("relu", nn.ReLU(True))
self.add_module("conv2", nn.Conv2d(n_in, n_out, 1, 1))
class Net(BaseNet):
def __init__(self, cfg, backbone):
super(Net, self).__init__()
self.cfg = cfg
self.backbone = backbone
self.emb_q = MLP(backbone.fdim, cfg.MODEL.FEATURE_DIM)
def lr_mult(self):
"""Learning rate multiplier for weights.
Returns: [old, new]"""
return 1., 1.
def lr_mult_bias(self):
"""Learning rate multiplier for bias.
Returns: [old, new]"""
return 2., 2.
def forward(self, frames, norm=True):
"""Forward pass to extract projection and task features"""
# extracting the time dimension
res4, res3 = self.backbone(frames)
# B,K,H,W
query = self.emb_q(res4)
if norm:
query = F.normalize(query, p=2, dim=1)
res3 = F.normalize(res3, p=2, dim=1)
res4 = F.normalize(res4, p=2, dim=1)
return query, res3, res4
================================================
FILE: models/resnet18.py
================================================
"""
Based on Jabri et al., (2020)
Credit: https://github.com/ajabri/videowalk.git
License: MIT
"""
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models.resnet as torch_resnet
from torchvision.models.resnet import BasicBlock
class ResNet(torch_resnet.ResNet):
def __init__(self, *args, **kwargs):
super(ResNet, self).__init__(*args, **kwargs)
def filter_layers(self, x):
return [l for l in x if getattr(self, l) is not None]
def remove_layers(self, remove_layers=[]):
# Remove extraneous layers
remove_layers += ['fc', 'avgpool']
for layer in self.filter_layers(remove_layers):
setattr(self, layer, None)
def modify(self):
# Set stride of layer3 and layer4 to 1 (from 2)
for layer in self.filter_layers(['layer3']):
for m in getattr(self, layer).modules():
if isinstance(m, torch.nn.Conv2d):
m.stride = (1, 1)
for layer in self.filter_layers(['layer4']):
for m in getattr(self, layer).modules():
if isinstance(m, torch.nn.Conv2d):
m.stride = (1, 1)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = x if self.maxpool is None else self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x3 = self.layer3(x)
x4 = self.layer4(x3)
return x4, x3
def _resnet(arch, block, layers, pretrained, **kwargs):
model = ResNet(block, layers, **kwargs)
return model
def resnet18(pretrained='', remove_layers=[], train=True, **kwargs):
model = _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, **kwargs)
model.modify()
model.remove_layers(remove_layers)
setattr(model, "fdim", 512)
return model
================================================
FILE: opts.py
================================================
"""
Copyright (c) 2021 TU Darmstadt
Author: Nikita Araslanov <nikita.araslanov@tu-darmstadt.de>
License: Apache License 2.0
"""
from __future__ import print_function
import os
import torch
import argparse
from core.config import cfg
def add_global_arguments(parser):
#
# Model details
#
parser.add_argument("--snapshot-dir", type=str, default='./snapshots',
help="Where to save snapshots of the model.")
parser.add_argument("--logdir", type=str, default='./logs',
help="Where to save log files of the model.")
parser.add_argument("--exp", type=str, default="main",
help="ID of the experiment (multiple runs)")
parser.add_argument("--run", type=str, help="ID of the run")
parser.add_argument('--workers', type=int, default=8,
metavar='N', help='dataloader threads')
parser.add_argument('--seed', default=64, type=int, help='seed for initializing training. ')
#
# Inference only
#
parser.add_argument("--infer-list", default="voc12/val.txt", type=str)
parser.add_argument('--mask-output-dir', type=str, default=None, help='path where to save masks')
parser.add_argument("--resume", type=str, default=None, help="Snapshot \"ID,iter\" to load")
#
# Configuration
#
parser.add_argument(
'--cfg', dest='cfg_file', required=True,
help='Config file for training (and optionally testing)')
parser.add_argument(
'--set', dest='set_cfgs',
help='Set config keys. Key-value pairs separated by whitespace, '
'e.g. [key] [value] [key] [value]',
default=[], nargs='+')
def maybe_create_dir(path):
if not os.path.exists(path):
os.makedirs(path)
def check_global_arguments(args):
args.cuda = torch.cuda.is_available()
print("Available threads: ", torch.get_num_threads())
args.logdir = os.path.join(args.logdir, args.exp, args.run)
maybe_create_dir(args.logdir)
#
# Model directories
#
args.snapshot_dir = os.path.join(args.snapshot_dir, args.exp, args.run)
maybe_create_dir(args.snapshot_dir)
def get_arguments(args_in):
"""Parse all the arguments provided from the CLI.
Returns:
The parsed arguments as an argparse.Namespace.
"""
parser = argparse.ArgumentParser(description="Dense Unsupervised Learning for Video Segmentation")
add_global_arguments(parser)
args = parser.parse_args(args_in)
check_global_arguments(args)
return args
================================================
FILE: requirements.txt
================================================
# This file may be used to install the dependencies using:
# $ pip install -r requirements.txt
# platform: linux-64
setproctitle
matplotlib
tensorboard
pyyaml
packaging
opencv-python
scikit-image
================================================
FILE: train.py
================================================
"""
Copyright (c) 2021 TU Darmstadt
Author: Nikita Araslanov <nikita.araslanov@tu-darmstadt.de>
License: Apache License 2.0
"""
from __future__ import print_function
import os
import sys
import numpy as np
import time
import random
import setproctitle
from functools import partial
import torch
import torch.nn.functional as F
from datasets import *
from models import get_model
from base_trainer import BaseTrainer
from opts import get_arguments
from core.config import cfg, cfg_from_file, cfg_from_list
from utils.timer import Timer
from utils.stat_manager import StatManager
from utils.davis2017 import evaluate_semi
from labelprop.crw import CRW
from torch.utils.tensorboard import SummaryWriter
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False
class Trainer(BaseTrainer):
def __init__(self, args, cfg):
super(Trainer, self).__init__(args, cfg)
# train loader for target domain
self.loader = get_dataloader(args, cfg, 'train')
# alias
self.denorm = self.loader.dataset.denorm
# val loaders for source and target domains
self.valloaders = get_dataloader(args, cfg, 'val')
# writers (only main)
self.writer_val = {}
for val_set in self.valloaders.keys():
logdir_val = os.path.join(args.logdir, val_set)
self.writer_val[val_set] = SummaryWriter(logdir_val)
# model
self.net = get_model(cfg, remove_layers=cfg.MODEL.REMOVE_LAYERS)
print("Train Net: ")
print(self.net)
# optimizer using different LR
net_params = self.net.parameter_groups(cfg.MODEL.LR, cfg.MODEL.WEIGHT_DECAY)
print("Optimising parameter groups: ")
for i, g in enumerate(net_params):
print("[{}]: # parameters: {}, lr = {:4.3e}".format(i, len(g["params"]), g["lr"]))
self.optim = self.get_optim(net_params, cfg.MODEL)
print("# of params: ", len(list(self.net.parameters())))
# LR scheduler
if cfg.MODEL.LR_SCHEDULER == "step":
self.scheduler = torch.optim.lr_scheduler.StepLR(self.optim, \
step_size=cfg.MODEL.LR_STEP, \
gamma=cfg.MODEL.LR_GAMMA)
elif cfg.MODEL.LR_SCHEDULER == "linear": # linear decay
def lr_lambda(epoch):
mult = 1 - epoch / (float(self.cfg.TRAIN.NUM_EPOCHS) - self.start_epoch)
mult = mult ** self.cfg.MODEL.LR_POWER
#print("Linear Scheduler: mult = {:4.3f}".format(mult))
return mult
self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optim, lr_lambda)
else:
self.scheduler = None
self.vis_batch = None
# using cuda
self.net.cuda()
self.crw = CRW(cfg.TEST)
# checkpoint management
self.checkpoint.create_model(self.net, self.optim)
if not args.resume is None:
self.start_epoch, self.best_score = self.checkpoint.load(args.resume, "cuda:0")
def step_seg(self, epoch, batch_src, key, temp=None, train=False, visualise=False, \
save_batch=False, writer=None, tag="train_src"):
frames, masks_gt, n_obj, seq_name = batch_src
# semi-supervised: select only the first
frames = frames.flatten(0,1)
masks_gt = masks_gt.flatten(0,1)
masks_gt = masks_gt[:, :n_obj.item()]
masks_ref = masks_gt.clone()
masks_ref[1:] *= 0
T = frames.shape[0]
fetch = {"res3": lambda x: x[0], \
"res4": lambda x: x[1], \
"key": lambda x: x[2]}
# number of iterations
bs = self.cfg.TRAIN.BATCH_SIZE
feats = []
t0 = time.time()
torch.cuda.empty_cache()
for t in range(0, T, bs):
# next frame
frames_batch = frames[t:t+bs].cuda()
# source forward pass
feats_ = self.net(frames_batch, embd_only=True)
feats.append(fetch[key](feats_).cpu())
feats = torch.cat(feats, 0)
print("Inference: {:4.3f}s".format(time.time() - t0))
sys.stdout.flush()
t0 = time.time()
outs = self.crw.forward(feats, masks_ref)
print("CRW propagation: {:4.3f}s".format(time.time() - t0))
sys.stdout.flush()
outs["masks_gt"] = masks_gt.argmax(1)
if visualise:
outs["frames"] = frames
self._visualise_seg(epoch, outs, writer, tag)
if save_batch:
# saving the batch for visualisation
self.save_vis_batch(tag, batch_src)
return outs
def step(self, epoch, batch_in, train=False, visualise=False, save_batch=False, writer=None, tag="train"):
frames1, frames2, affine1, affine2 = batch_in
assert frames1.size() == frames2.size(), "Frames shape mismatch"
# We could simply do
# images1 = frames1.flatten(0,1).cuda()
# images2 = frames2.flatten(0,1).cuda()
# Instead we pull the reference frame from the 2nd view
# to the first view so that the regularising branch is
# always in evaluation mode to save the GPU memory
B,T,C,H,W = frames1.shape
images1 = torch.cat((frames1, frames2[:, ::T]), 1)
images1 = images1.flatten(0,1).cuda()
images2 = frames2[:, 1:].flatten(0,1).cuda()
affine1 = affine1.flatten(0,1).cuda()
affine2 = affine2.flatten(0,1).cuda()
# source forward pass
losses, outs = self.net(images1, frames2=images2, T=T, \
affine=affine1, affine2=affine2, \
dbg=visualise)
if train:
self.optim.zero_grad()
losses["main"].backward()
self.optim.step()
if visualise:
self._visualise(epoch, outs, T, writer, tag)
if save_batch:
# Saving batch for visualisation
self.save_vis_batch(tag, batch_in)
# summarising the losses
# into python scalars
losses_ret = {}
for key, val in losses.items():
losses_ret[key] = val.mean().item()
return losses_ret, outs
def train_epoch(self, epoch):
stat = StatManager()
# adding stats for classes
timer = Timer("Epoch {}".format(epoch))
step = partial(self.step, train=True, visualise=False)
# training mode
self.net.train()
for i, batch in enumerate(self.loader):
save_batch = i == 0
#
# Forward pass
#
losses, _ = step(epoch, batch, save_batch=save_batch, tag="train")
for loss_key, loss_val in losses.items():
stat.update_stats(loss_key, loss_val)
# intermediate logging
if i % 10 == 0:
msg = "Loss [{:04d}]: ".format(i)
for loss_key, loss_val in losses.items():
msg += " {} {:.4f} | ".format(loss_key, loss_val)
msg += " | Im/Sec: {:.1f}".format(i * self.cfg.TRAIN.BATCH_SIZE / timer.get_stage_elapsed())
print(msg)
sys.stdout.flush()
for name, val in stat.items():
print("{}: {:4.3f}".format(name, val))
self.writer.add_scalar('all/{}'.format(name), val, epoch)
# plotting learning rate
for ii, l in enumerate(self.optim.param_groups):
print("Learning rate [{}]: {:4.3e}".format(ii, l['lr']))
self.writer.add_scalar('lr/enc_group_%02d' % ii, l['lr'], epoch)
# plotting moment distance
if stat.has_vals("lr_gamma"):
self.writer.add_scalar('hyper/gamma', stat.summarize_key("lr_gamma"), epoch)
if epoch % self.cfg.LOG.ITER_TRAIN == 0:
self.visualise_results(epoch, self.writer, "train", self.step)
def validation(self, epoch, writer, loader, tag=None, max_iter=None):
stat = StatManager()
if max_iter is None:
max_iter = len(loader)
# Fast test during the training
def eval_batch(batch):
loss, masks = self.step(epoch, batch, train=False, visualise=False)
for loss_key, loss_val in loss.items():
stat.update_stats(loss_key, loss_val)
return masks
self.net.eval()
print("Starting validation")
sys.stdout.flush()
for n, batch in enumerate(loader):
with torch.no_grad():
# note video length assumed 1
eval_batch(batch)
if not tag is None and not self.has_vis_batch(tag):
self.save_vis_batch(tag, batch)
checkpoint_score = 0.0
# total classification loss
for stat_key, stat_val in stat.items():
writer.add_scalar('all/{}'.format(stat_key), stat_val, epoch)
print('Loss {}: {:4.3f}'.format(stat_key, stat_val))
if not tag is None and epoch % self.cfg.LOG.ITER_TRAIN == 0:
self.visualise_results(epoch, writer, tag, self.step)
return checkpoint_score
def validation_seg(self, epoch, writer, loader, key="all", temp=None, tag=None, max_iter=None):
vis = key == "res4"
stat = StatManager()
if max_iter is None:
max_iter = len(loader)
if temp is None:
temp = self.cfg.TEST.TEMP
step_fn = partial(self.step_seg, key=key, temp=temp, train=False, visualise=vis, writer=writer)
# Fast test during the training
def eval_batch(n, batch):
tag_n = tag + "_{:02d}".format(n)
masks = step_fn(epoch, batch, tag=tag_n)
return masks
self.net.eval()
def davis_mask(masks):
masks = masks.cpu()
num_objects = int(masks.max())
tmp = torch.ones(num_objects, *masks.shape)
tmp = tmp * torch.arange(1, num_objects + 1)[:, None, None, None]
return (tmp == masks[None, ...]).long().numpy()
Js = {"M": [], "R": [], "D": []}
Fs = {"M": [], "R": [], "D": []}
timer = Timer("[Epoch {}] Validation-Seg".format(epoch))
tag_key = "{}_{}_{:3.2f}".format(tag, key, temp)
for n, batch in enumerate(loader):
seq_name = batch[-1][0]
print("Sequence: ", seq_name)
sys.stdout.flush()
with torch.no_grad():
masks_out = eval_batch(n, batch)
# second element is assumed to be always GT masks
masks_gt = davis_mask(masks_out["masks_gt"])
masks_pred = davis_mask(masks_out["masks_pred_idx"])
assert masks_gt.shape == masks_pred.shape
# converting to a digestible format
# [num_objects, seq_length, height, width]
if not tag_key is None and not self.has_vis_batch(tag_key):
self.save_vis_batch(tag_key, batch)
start_t = time.time()
metrics_res = evaluate_semi((masks_gt, ), (masks_pred, ))
J, F = metrics_res['J'], metrics_res['F']
print("Evaluation: {:4.3f}s".format(time.time() - start_t))
print("Jaccard: ", J["M"])
print("F-Score: ", F["M"])
for l in ("M", "R", "D"):
Js[l] += J[l]
Fs[l] += F[l]
msg = "{} | Im/Sec: {:.1f}".format(n, n * batch[0].shape[1] / timer.get_stage_elapsed())
print(msg)
sys.stdout.flush()
g_measures = ['J&F-Mean', 'J-Mean', 'J-Recall', 'J-Decay', 'F-Mean', 'F-Recall', 'F-Decay']
# Generate dataframe for the general results
final_mean = (np.mean(Js["M"]) + np.mean(Fs["M"])) / 2.
g_res = [final_mean, \
np.mean(Js["M"]), np.mean(Js["R"]), np.mean(Js["D"]), \
np.mean(Fs["M"]), np.mean(Fs["R"]), np.mean(Fs["D"])]
for (name, val) in zip(g_measures, g_res):
writer.add_scalar('{}_{:3.2f}/{}'.format(key, temp, name), val, epoch)
print('{}: {:4.3f}'.format(name, val))
return final_mean
def train(args, cfg):
setproctitle.setproctitle("dense-ulearn | {}".format(args.run))
if args.seed is not None:
print("Setting the seed: {}".format(args.seed))
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
trainer = Trainer(args, cfg)
timer = Timer()
def time_call(func, msg, *args, **kwargs):
timer.reset_stage()
val = func(*args, **kwargs)
print(msg + (" {:3.2}m".format(timer.get_stage_elapsed() / 60.)))
return val
for epoch in range(trainer.start_epoch, cfg.TRAIN.NUM_EPOCHS + 1):
# training 1 epoch
time_call(trainer.train_epoch, "Train epoch: ", epoch)
print("Epoch >>> {:02d} <<< ".format(epoch))
if epoch % cfg.LOG.ITER_VAL == 0:
for val_set in ("val_video", ):
time_call(trainer.validation, "Validation / {} / Val: ".format(val_set), \
epoch, trainer.writer_val[val_set], trainer.valloaders[val_set], tag=val_set)
best_layer = None
best_score = -1e10
for val_set in ("val_video_seg", ):
writer = trainer.writer_val[val_set]
loader = trainer.valloaders[val_set]
for layer in ("key", "res4"):
msg = ">>> Validation {} / {} <<<".format(layer, val_set)
score = time_call(trainer.validation_seg, msg, epoch, writer, loader, key=layer, tag=val_set)
if score > best_score:
best_score = score
best_layer = layer
print("Best score / layer: {:4.2f} / {}".format(best_score, best_layer))
if val_set =="val_video_seg":
trainer.checkpoint_best(best_score, epoch, best_layer)
if not trainer.scheduler is None and cfg.MODEL.LR_SCHED_USE_EPOCH:
trainer.scheduler.step()
def main():
args = get_arguments(sys.argv[1:])
# Reading the config
cfg_from_file(args.cfg_file)
if args.set_cfgs is not None:
cfg_from_list(args.set_cfgs)
train(args, cfg)
if __name__ == "__main__":
main()
================================================
FILE: utils/checkpoints.py
================================================
"""
Copyright (c) 2021 TU Darmstadt
Author: Nikita Araslanov <nikita.araslanov@tu-darmstadt.de>
License: Apache License 2.0
"""
import os
import sys
import torch
class Checkpoint(object):
def __init__(self, path, max_n=3):
self.path = path
self.max_n = max_n
self.models = {}
self.checkpoints = []
def create_model(self, model, opt):
self.models = {}
self.models['model'] = model
self.models['opt'] = opt
def limit(self):
return self.max_n
def __len__(self):
return len(self.checkpoints)
def _get_full_path(self, suffix):
filename = self._filename(suffix)
return os.path.join(self.path, filename)
def clean(self):
n_remove = max(0, len(self.checkpoints) - self.max_n)
for i in range(n_remove):
self._rm(self.checkpoints[i])
self.checkpoints = self.checkpoints[n_remove:]
def _rm(self, suffix):
path = self._get_full_path(suffix)
if os.path.isfile(path):
os.remove(path)
def _filename(self, suffix):
return "{}.pth".format(suffix)
def load(self, path, location):
    if len(path) > 0 and not os.path.isfile(path):
        print("Snapshot {} not found".format(path))
        # abort early: torch.load below would otherwise fail with a cryptic error
        sys.exit(1)
data = torch.load(path, map_location=location)
data_mapped = {}
for key, val in data["model"].items():
data_mapped[key.replace("module.", "")] = val
self.models["model"].load_state_dict(data_mapped, strict=True)
return data["epoch"], data["score"]
def checkpoint(self, score, epoch, t):
suffix = "epoch{:03d}_score{:4.3f}_{}".format(epoch, score, t)
self.checkpoints.append(suffix)
path = self._get_full_path(suffix)
if not os.path.isfile(path):
torch.save({"model": self.models["model"].state_dict(),
"opt": self.models["opt"].state_dict(),
"score": score,
"epoch": epoch}, path)
# removing if more than allowed number of snapshots
self.clean()
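# Hedged usage sketch (not part of the original file): how trainer code is expected to
# drive the Checkpoint class above. The tiny linear model, optimizer, scores and the
# temporary directory are illustrative assumptions, not taken from this repository.
if __name__ == "__main__":
    import tempfile
    import torch.nn as nn
    import torch.optim as optim

    model = nn.Linear(4, 2)
    opt = optim.SGD(model.parameters(), lr=0.1)

    with tempfile.TemporaryDirectory() as snapshot_dir:
        ckpt = Checkpoint(snapshot_dir, max_n=2)   # keep at most two snapshots on disk
        ckpt.create_model(model, opt)
        for epoch, score in enumerate([0.41, 0.55, 0.62], start=1):
            ckpt.checkpoint(score, epoch, "key")   # older snapshots are removed by clean()
        print("snapshots kept:", len(ckpt))        # -> 2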
================================================
FILE: utils/collections.py
================================================
# Copyright (c) 2017-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
"""A simple attribute dictionary used for representing configuration options."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
class AttrDict(dict):
IMMUTABLE = '__immutable__'
def __init__(self, *args, **kwargs):
super(AttrDict, self).__init__(*args, **kwargs)
self.__dict__[AttrDict.IMMUTABLE] = False
def __getattr__(self, name):
if name in self.__dict__:
return self.__dict__[name]
elif name in self:
return self[name]
else:
raise AttributeError(name)
def __setattr__(self, name, value):
if not self.__dict__[AttrDict.IMMUTABLE]:
if name in self.__dict__:
self.__dict__[name] = value
else:
self[name] = value
else:
raise AttributeError(
'Attempted to set "{}" to "{}", but AttrDict is immutable'.
format(name, value)
)
def immutable(self, is_immutable):
"""Set immutability to is_immutable and recursively apply the setting
to all nested AttrDicts.
"""
self.__dict__[AttrDict.IMMUTABLE] = is_immutable
# Recursively set immutable state
for v in self.__dict__.values():
if isinstance(v, AttrDict):
v.immutable(is_immutable)
for v in self.values():
if isinstance(v, AttrDict):
v.immutable(is_immutable)
def is_immutable(self):
return self.__dict__[AttrDict.IMMUTABLE]
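# Hedged usage sketch (not part of the original file): minimal demonstration of the
# AttrDict defined above; the config keys used here are made up for illustration.
if __name__ == "__main__":
    cfg = AttrDict()
    cfg.TRAIN = AttrDict()
    cfg.TRAIN.BATCH_SIZE = 8
    print(cfg.TRAIN.BATCH_SIZE)     # attribute access reads dict entries -> 8
    cfg.immutable(True)             # recursively freeze the config
    try:
        cfg.TRAIN.BATCH_SIZE = 16
    except AttributeError as err:
        print("frozen:", err)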
================================================
FILE: utils/davis2017.py
================================================
"""
Credit: https://github.com/davisvideochallenge/davis2017-evaluation.git
License: BSD 3-Clause
Copyright (c) 2020, DAVIS: Densely Annotated VIdeo Segmentation
"""
import sys
import numpy as np
from utils.davis2017_metrics import db_eval_boundary, db_eval_iou
import utils.davis2017_utils as utils
def _evaluate_semisupervised(all_gt_masks, all_res_masks, all_void_masks, metric):
if all_res_masks.shape[0] > all_gt_masks.shape[0]:
sys.stdout.write("\nIn your PNG files there is an index higher than the number of objects in the sequence!")
sys.exit()
elif all_res_masks.shape[0] < all_gt_masks.shape[0]:
sys.stdout.write("\nThe number of predictions is less than ground truth. Padding with zero.")
zero_padding = np.zeros((all_gt_masks.shape[0] - all_res_masks.shape[0], *all_res_masks.shape[1:]))
all_res_masks = np.concatenate([all_res_masks, zero_padding], axis=0)
j_metrics_res, f_metrics_res = np.zeros(all_gt_masks.shape[:2]), np.zeros(all_gt_masks.shape[:2])
for ii in range(all_gt_masks.shape[0]):
if 'J' in metric:
j_metrics_res[ii, :] = db_eval_iou(all_gt_masks[ii, ...], all_res_masks[ii, ...], all_void_masks)
sys.stdout.flush()
if 'F' in metric:
f_metrics_res[ii, :] = db_eval_boundary(all_gt_masks[ii, ...], all_res_masks[ii, ...], all_void_masks)
sys.stdout.flush()
return j_metrics_res, f_metrics_res
def evaluate_semi(all_gt_masks, all_res_masks, metric=('J', 'F'), debug=False):
metric = metric if isinstance(metric, tuple) or isinstance(metric, list) else [metric]
if 'T' in metric:
raise ValueError('Temporal metric not supported!')
if 'J' not in metric and 'F' not in metric:
raise ValueError('Metric possible values are J for IoU or F for Boundary')
# Containers
metrics_res = {}
if 'J' in metric:
metrics_res['J'] = {"M": [], "R": [], "D": [], "M_per_object": {}}
if 'F' in metric:
metrics_res['F'] = {"M": [], "R": [], "D": [], "M_per_object": {}}
for seq, (seq_gt_masks, seq_res_masks) in enumerate(zip(all_gt_masks, all_res_masks)):
seq_gt_masks, seq_res_masks = seq_gt_masks[:, 1:-1, :, :], seq_res_masks[:, 1:-1, :, :]
j_metrics_res, f_metrics_res = _evaluate_semisupervised(seq_gt_masks, seq_res_masks, None, metric)
for ii in range(seq_gt_masks.shape[0]):
seq_name = f'{seq}_{ii+1}'
if 'J' in metric:
[JM, JR, JD] = utils.db_statistics(j_metrics_res[ii])
metrics_res['J']["M"].append(JM)
metrics_res['J']["R"].append(JR)
metrics_res['J']["D"].append(JD)
metrics_res['J']["M_per_object"][seq_name] = JM
if 'F' in metric:
[FM, FR, FD] = utils.db_statistics(f_metrics_res[ii])
metrics_res['F']["M"].append(FM)
metrics_res['F']["R"].append(FR)
metrics_res['F']["D"].append(FD)
metrics_res['F']["M_per_object"][seq_name] = FM
# Show progress
if debug:
sys.stdout.write(str(seq) + '\n')  # seq is an integer index here, so cast before concatenating
sys.stdout.flush()
return metrics_res
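# Hedged usage sketch (not part of the original file): evaluating one toy sequence with
# evaluate_semi. Mask shapes follow the [num_objects, seq_length, height, width]
# convention used above; the values are made up, and evaluate_semi itself discards the
# first and last frame of each sequence. Assumes a NumPy version that still provides
# np.bool, which the metrics code relies on.
if __name__ == "__main__":
    gt = np.zeros((1, 4, 16, 16), dtype=np.uint8)
    gt[0, :, 4:12, 4:12] = 1        # one object: a static 8x8 square
    pred = np.zeros_like(gt)
    pred[0, :, 4:12, 5:13] = 1      # the same square, shifted right by one pixel
    res = evaluate_semi((gt,), (pred,))
    print("J mean:", np.mean(res['J']['M']), "| F mean:", np.mean(res['F']['M']))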
================================================
FILE: utils/davis2017_metrics.py
================================================
"""
Credit: https://github.com/davisvideochallenge/davis2017-evaluation.git
License: BSD 3-Clause
Copyright (c) 2020, DAVIS: Densely Annotated VIdeo Segmentation
"""
import math
import sys
import numpy as np
import cv2
cv2.setNumThreads(0)
from skimage.morphology import disk
def db_eval_iou(annotation, segmentation, void_pixels=None):
""" Compute region similarity as the Jaccard Index.
Arguments:
annotation (ndarray): binary annotation map.
segmentation (ndarray): binary segmentation map.
void_pixels (ndarray): optional mask with void pixels
Return:
jaccard (float): region similarity
"""
assert annotation.shape == segmentation.shape, \
f'Annotation({annotation.shape}) and segmentation:{segmentation.shape} dimensions do not match.'
annotation = annotation.astype(np.bool)
segmentation = segmentation.astype(np.bool)
if void_pixels is not None:
assert annotation.shape == void_pixels.shape, \
f'Annotation({annotation.shape}) and void pixels:{void_pixels.shape} dimensions do not match.'
void_pixels = void_pixels.astype(np.bool)
else:
void_pixels = np.zeros_like(segmentation)
# Intersection between all sets
inters = np.sum((segmentation & annotation) & np.logical_not(void_pixels), axis=(-2, -1))
union = np.sum((segmentation | annotation) & np.logical_not(void_pixels), axis=(-2, -1))
j = inters / union
if j.ndim == 0:
j = 1 if np.isclose(union, 0) else j
else:
j[np.isclose(union, 0)] = 1
return j
def db_eval_boundary(annotation, segmentation, void_pixels=None, bound_th=0.008):
assert annotation.shape == segmentation.shape
if void_pixels is not None:
assert annotation.shape == void_pixels.shape
if annotation.ndim == 3:
n_frames = annotation.shape[0]
f_res = np.zeros(n_frames)
for frame_id in range(n_frames):
void_pixels_frame = None if void_pixels is None else void_pixels[frame_id, :, :]
f_res[frame_id] = f_measure(segmentation[frame_id, :, :], annotation[frame_id, :, :], void_pixels_frame, bound_th=bound_th)
elif annotation.ndim == 2:
f_res = f_measure(segmentation, annotation, void_pixels, bound_th=bound_th)
else:
raise ValueError(f'db_eval_boundary does not support tensors with {annotation.ndim} dimensions')
return f_res
def f_measure(foreground_mask, gt_mask, void_pixels=None, bound_th=0.008):
"""
Compute mean,recall and decay from per-frame evaluation.
Calculates precision/recall for boundaries between foreground_mask and
gt_mask using morphological operators to speed it up.
Arguments:
foreground_mask (ndarray): binary segmentation image.
gt_mask (ndarray): binary annotated image.
void_pixels (ndarray): optional mask with void pixels
Returns:
F (float): boundaries F-measure
"""
assert np.atleast_3d(foreground_mask).shape[2] == 1
if void_pixels is not None:
void_pixels = void_pixels.astype(np.bool)
else:
void_pixels = np.zeros_like(foreground_mask).astype(np.bool)
bound_pix = bound_th if bound_th >= 1 else \
np.ceil(bound_th * np.linalg.norm(foreground_mask.shape))
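# For example, on a 480x854 DAVIS-style frame the default bound_th = 0.008 gives
# ceil(0.008 * sqrt(480^2 + 854^2)) = 8 pixels of boundary tolerance.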
# Get the pixel boundaries of both masks
fg_boundary = _seg2bmap(foreground_mask * np.logical_not(void_pixels))
gt_boundary = _seg2bmap(gt_mask * np.logical_not(void_pixels))
# fg_dil = binary_dilation(fg_boundary, disk(bound_pix))
fg_dil = cv2.dilate(fg_boundary.astype(np.uint8), disk(bound_pix).astype(np.uint8))
# gt_dil = binary_dilation(gt_boundary, disk(bound_pix))
gt_dil = cv2.dilate(gt_boundary.astype(np.uint8), disk(bound_pix).astype(np.uint8))
# Get the intersection
gt_match = gt_boundary * fg_dil
fg_match = fg_boundary * gt_dil
# Area of the intersection
n_fg = np.sum(fg_boundary)
n_gt = np.sum(gt_boundary)
# % Compute precision and recall
if n_fg == 0 and n_gt > 0:
precision = 1
recall = 0
elif n_fg > 0 and n_gt == 0:
precision = 0
recall = 1
elif n_fg == 0 and n_gt == 0:
precision = 1
recall = 1
else:
precision = np.sum(fg_match) / float(n_fg)
recall = np.sum(gt_match) / float(n_gt)
# Compute F measure
if precision + recall == 0:
F = 0
else:
F = 2 * precision * recall / (precision + recall)
return F
def _seg2bmap(seg, width=None, height=None):
"""
From a segmentation, compute a binary boundary map with 1 pixel wide
boundaries. The boundary pixels are offset by 1/2 pixel towards the
origin from the actual segment boundary.
Arguments:
seg : Segments labeled from 1..k.
width : Width of desired bmap <= seg.shape[1]
height : Height of desired bmap <= seg.shape[0]
Returns:
bmap (ndarray): Binary boundary map.
David Martin <dmartin@eecs.berkeley.edu>
January 2003
"""
seg = seg.astype(np.bool)
seg[seg > 0] = 1
assert np.atleast_3d(seg).shape[2] == 1
width = seg.shape[1] if width is None else width
height = seg.shape[0] if height is None else height
h, w = seg.shape[:2]
ar1 = float(width) / float(height)
ar2 = float(w) / float(h)
assert not (
width > w or height > h or abs(ar1 - ar2) > 0.01
), "Can't convert %dx%d seg to %dx%d bmap." % (w, h, width, height)
e = np.zeros_like(seg)
s = np.zeros_like(seg)
se = np.zeros_like(seg)
e[:, :-1] = seg[:, 1:]
s[:-1, :] = seg[1:, :]
se[:-1, :-1] = seg[1:, 1:]
b = seg ^ e | seg ^ s | seg ^ se
b[-1, :] = seg[-1, :] ^ e[-1, :]
b[:, -1] = seg[:, -1] ^ s[:, -1]
b[-1, -1] = 0
if w == width and h == height:
bmap = b
else:
bmap = np.zeros((height, width))
for x in range(w):
for y in range(h):
if b[y, x]:
j = 1 + math.floor((y - 1) + height / h)
i = 1 + math.floor((x - 1) + width / h)
bmap[j, i] = 1
return bmap
if __name__ == '__main__':
from davis2017.davis import DAVIS
from davis2017.results import Results
dataset = DAVIS(root='input_dir/ref', subset='val', sequences='aerobatics')
results = Results(root_dir='examples/osvos')
# Test timing F measure
for seq in dataset.get_sequences():
all_gt_masks, _, all_masks_id = dataset.get_all_masks(seq, True)
all_gt_masks, all_masks_id = all_gt_masks[:, 1:-1, :, :], all_masks_id[1:-1]
all_res_masks = results.read_masks(seq, all_masks_id)
f_metrics_res = np.zeros(all_gt_masks.shape[:2])
for ii in range(all_gt_masks.shape[0]):
f_metrics_res[ii, :] = db_eval_boundary(all_gt_masks[ii, ...], all_res_masks[ii, ...])
# Run using to profile code: python -m cProfile -o f_measure.prof metrics.py
# snakeviz f_measure.prof
================================================
FILE: utils/davis2017_utils.py
================================================
"""
Credit: https://github.com/davisvideochallenge/davis2017-evaluation.git
License: BSD 3-Clause
Copyright (c) 2020, DAVIS: Densely Annotated VIdeo Segmentation
"""
import os
import errno
import numpy as np
from PIL import Image
import warnings
def _pascal_color_map(N=256, normalized=False):
"""
Python implementation of the color map function for the PASCAL VOC data set.
Official Matlab version can be found in the PASCAL VOC devkit
http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html#devkit
"""
def bitget(byteval, idx):
return (byteval & (1 << idx)) != 0
dtype = 'float32' if normalized else 'uint8'
cmap = np.zeros((N, 3), dtype=dtype)
for i in range(N):
r = g = b = 0
c = i
for j in range(8):
r = r | (bitget(c, 0) << 7 - j)
g = g | (bitget(c, 1) << 7 - j)
b = b | (bitget(c, 2) << 7 - j)
c = c >> 3
cmap[i] = np.array([r, g, b])
cmap = cmap / 255 if normalized else cmap
return cmap
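# For reference, the first entries produced by _pascal_color_map above are
# 0 -> (0, 0, 0), 1 -> (128, 0, 0), 2 -> (0, 128, 0), 3 -> (128, 128, 0),
# i.e. the standard PASCAL VOC palette.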
def overlay_semantic_mask(im, ann, alpha=0.5, colors=None, contour_thickness=None):
im, ann = np.asarray(im, dtype=np.uint8), np.asarray(ann, dtype=np.int)
if im.shape[:-1] != ann.shape:
raise ValueError('First two dimensions of `im` and `ann` must match')
if im.shape[-1] != 3:
raise ValueError('im must have three channels in the third dimension')
colors = colors or _pascal_color_map()
colors = np.asarray(colors, dtype=np.uint8)
mask = colors[ann]
fg = im * alpha + (1 - alpha) * mask
img = im.copy()
img[ann > 0] = fg[ann > 0]
if contour_thickness: # pragma: no cover
import cv2
for obj_id in np.unique(ann[ann > 0]):
contours = cv2.findContours((ann == obj_id).astype(
np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2:]
cv2.drawContours(img, contours[0], -1, colors[obj_id].tolist(),
contour_thickness)
return img
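# Note: generate_obj_proposals and generate_random_permutation_gt_obj_proposals below
# reference a DAVIS dataset class that is not imported in this file; they come from the
# upstream davis2017-evaluation toolkit and will raise a NameError if called as-is.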
def generate_obj_proposals(davis_root, subset, num_proposals, save_path):
dataset = DAVIS(davis_root, subset=subset, codalab=True)
for seq in dataset.get_sequences():
save_dir = os.path.join(save_path, seq)
if os.path.exists(save_dir):
continue
all_gt_masks, all_masks_id = dataset.get_all_masks(seq, True)
img_size = all_gt_masks.shape[2:]
num_rows = int(np.ceil(np.sqrt(num_proposals)))
proposals = np.zeros((num_proposals, len(all_masks_id), *img_size))
height_slices = np.floor(np.arange(0, img_size[0] + 1, img_size[0]/num_rows)).astype(np.uint).tolist()
width_slices = np.floor(np.arange(0, img_size[1] + 1, img_size[1]/num_rows)).astype(np.uint).tolist()
ii = 0
prev_h, prev_w = 0, 0
for h in height_slices[1:]:
for w in width_slices[1:]:
proposals[ii, :, prev_h:h, prev_w:w] = 1
prev_w = w
ii += 1
if ii == num_proposals:
break
prev_h, prev_w = h, 0
if ii == num_proposals:
break
os.makedirs(save_dir, exist_ok=True)
for i, mask_id in enumerate(all_masks_id):
mask = np.sum(proposals[:, i, ...] * np.arange(1, proposals.shape[0] + 1)[:, None, None], axis=0)
save_mask(mask, os.path.join(save_dir, f'{mask_id}.png'))
def generate_random_permutation_gt_obj_proposals(davis_root, subset, save_path):
dataset = DAVIS(davis_root, subset=subset, codalab=True)
for seq in dataset.get_sequences():
gt_masks, all_masks_id = dataset.get_all_masks(seq, True)
obj_swap = np.random.permutation(np.arange(gt_masks.shape[0]))
gt_masks = gt_masks[obj_swap, ...]
save_dir = os.path.join(save_path, seq)
os.makedirs(save_dir, exist_ok=True)
for i, mask_id in enumerate(all_masks_id):
mask = np.sum(gt_masks[:, i, ...] * np.arange(1, gt_masks.shape[0] + 1)[:, None, None], axis=0)
save_mask(mask, os.path.join(save_dir, f'{mask_id}.png'))
def color_map(N=256, normalized=False):
def bitget(byteval, idx):
return ((byteval & (1 << idx)) != 0)
dtype = 'float32' if normalized else 'uint8'
cmap = np.zeros((N, 3), dtype=dtype)
for i in range(N):
r = g = b = 0
c = i
for j in range(8):
r = r | (bitget(c, 0) << 7-j)
g = g | (bitget(c, 1) << 7-j)
b = b | (bitget(c, 2) << 7-j)
c = c >> 3
cmap[i] = np.array([r, g, b])
cmap = cmap/255 if normalized else cmap
return cmap
def save_mask(mask, img_path):
if np.max(mask) > 255:
raise ValueError('Maximum id pixel value is 255')
mask_img = Image.fromarray(mask.astype(np.uint8))
mask_img.putpalette(color_map().flatten().tolist())
mask_img.save(img_path)
def db_statistics(per_frame_values):
""" Compute mean,recall and decay from per-frame evaluation.
Arguments:
per_frame_values (ndarray): per-frame evaluation
Returns:
M,O,D (float,float,float):
return evaluation statistics: mean,recall,decay.
"""
# strip off nan values
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=RuntimeWarning)
M = np.nanmean(per_frame_values)
O = np.nanmean(per_frame_values > 0.5)
N_bins = 4
ids = np.round(np.linspace(1, len(per_frame_values), N_bins + 1) + 1e-10) - 1
ids = ids.astype(np.uint8)
D_bins = [per_frame_values[ids[i]:ids[i + 1] + 1] for i in range(0, 4)]
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=RuntimeWarning)
D = np.nanmean(D_bins[0]) - np.nanmean(D_bins[3])
return M, O, D
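# Illustrative example (made-up values): per_frame_values = [0.9, 0.8, 0.6, 0.5] gives
# M = 0.7 (mean), O = 0.75 (three of the four frames exceed 0.5) and
# D = mean(first bin) - mean(last bin) = 0.85 - 0.55 = 0.3 (performance decay).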
def list_files(dir, extension=".png"):
return [os.path.splitext(file_)[0] for file_ in os.listdir(dir) if file_.endswith(extension)]
def force_symlink(file1, file2):
try:
os.symlink(file1, file2)
except OSError as e:
if e.errno == errno.EEXIST:
os.remove(file2)
os.symlink(file1, file2)
================================================
FILE: utils/palette.py
================================================
"""
Copyright (c) 2021 TU Darmstadt
Author: Nikita Araslanov <nikita.araslanov@tu-darmstadt.de>
License: Apache License 2.0
"""
import matplotlib.cm as cm
import numpy as np
from PIL import ImagePalette
def colormap(N=256):
def bitget(byteval, idx):
return ((byteval & (1 << idx)) != 0)
dtype = 'uint8'
cmap = []
for i in range(N):
r = g = b = 0
c = i
for j in range(8):
r = r | (bitget(c, 0) << 7-j)
g = g | (bitget(c, 1) << 7-j)
b = b | (bitget(c, 2) << 7-j)
c = c >> 3
cmap.append((r, g, b))
return cmap
def apply_cmap(masks_pred, cmap):
canvas = np.zeros((masks_pred.shape[0], masks_pred.shape[1], 3))
for label in np.unique(masks_pred):
canvas[masks_pred == label, :] = cmap[label]
return canvas #np.transpose(canvas, [2,0,1])
def create_palette(colormap, num):
cmap = cm.get_cmap(colormap)
palette = ImagePalette.ImagePalette()
for n in range(num):
val = n / num
rgb = [int(255*x) for x in cmap(val)[:-1]]
palette.getcolor(tuple(rgb))
return palette
def custom_palette(nclasses, cname="rainbow"):
cmap = cm.get_cmap(cname, nclasses)
return cmap
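# Hedged usage sketch (not part of the original file): colouring a toy index mask with
# the PASCAL-style colormap above; the 2x2 mask is made up for illustration.
if __name__ == "__main__":
    cmap = colormap()
    toy_mask = np.array([[0, 1], [2, 1]])
    rgb = apply_cmap(toy_mask, cmap)   # (2, 2, 3) canvas of per-label RGB values
    print(rgb.shape, rgb[0, 1])        # label 1 maps to (128, 0, 0)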
================================================
FILE: utils/palette_davis.py
================================================
palette_str = '''0 0 0
128 0 0
0 128 0
128 128 0
0 0 128
128 0 128
0 128 128
128 128 128
64 0 0
191 0 0
64 128 0
191 128 0
64 0 128
191 0 128
64 128 128
191 128 128
0 64 0
128 64 0
0 191 0
128 191 0
0 64 128
128 64 128
22 22 22
23 23 23
24 24 24
25 25 25
26 26 26
27 27 27
28 28 28
29 29 29
30 30 30
31 31 31
32 32 32
33 33 33
34 34 34
35 35 35
36 36 36
37 37 37
38 38 38
39 39 39
40 40 40
41 41 41
42 42 42
43 43 43
44 44 44
45 45 45
46 46 46
47 47 47
48 48 48
49 49 49
50 50 50
51 51 51
52 52 52
53 53 53
54 54 54
55 55 55
56 56 56
57 57 57
58 58 58
59 59 59
60 60 60
61 61 61
62 62 62
63 63 63
64 64 64
65 65 65
66 66 66
67 67 67
68 68 68
69 69 69
70 70 70
71 71 71
72 72 72
73 73 73
74 74 74
75 75 75
76 76 76
77 77 77
78 78 78
79 79 79
80 80 80
81 81 81
82 82 82
83 83 83
84 84 84
85 85 85
86 86 86
87 87 87
88 88 88
89 89 89
90 90 90
91 91 91
92 92 92
93 93 93
94 94 94
95 95 95
96 96 96
97 97 97
98 98 98
99 99 99
100 100 100
101 101 101
102 102 102
103 103 103
104 104 104
105 105 105
106 106 106
107 107 107
108 108 108
109 109 109
110 110 110
111 111 111
112 112 112
113 113 113
114 114 114
115 115 115
116 116 116
117 117 117
118 118 118
119 119 119
120 120 120
121 121 121
122 122 122
123 123 123
124 124 124
125 125 125
126 126 126
127 127 127
128 128 128
129 129 129
130 130 130
131 131 131
132 132 132
133 133 133
134 134 134
135 135 135
136 136 136
137 137 137
138 138 138
139 139 139
140 140 140
141 141 141
142 142 142
143 143 143
144 144 144
145 145 145
146 146 146
147 147 147
148 148 148
149 149 149
150 150 150
151 151 151
152 152 152
153 153 153
154 154 154
155 155 155
156 156 156
157 157 157
158 158 158
159 159 159
160 160 160
161 161 161
162 162 162
163 163 163
164 164 164
165 165 165
166 166 166
167 167 167
168 168 168
169 169 169
170 170 170
171 171 171
172 172 172
173 173 173
174 174 174
175 175 175
176 176 176
177 177 177
178 178 178
179 179 179
180 180 180
181 181 181
182 182 182
183 183 183
184 184 184
185 185 185
186 186 186
187 187 187
188 188 188
189 189 189
190 190 190
191 191 191
192 192 192
193 193 193
194 194 194
195 195 195
196 196 196
197 197 197
198 198 198
199 199 199
200 200 200
201 201 201
202 202 202
203 203 203
204 204 204
205 205 205
206 206 206
207 207 207
208 208 208
209 209 209
210 210 210
211 211 211
212 212 212
213 213 213
214 214 214
215 215 215
216 216 216
217 217 217
218 218 218
219 219 219
220 220 220
221 221 221
222 222 222
223 223 223
224 224 224
225 225 225
226 226 226
227 227 227
228 228 228
229 229 229
230 230 230
231 231 231
232 232 232
233 233 233
234 234 234
235 235 235
236 236 236
237 237 237
238 238 238
239 239 239
240 240 240
241 241 241
242 242 242
243 243 243
244 244 244
245 245 245
246 246 246
247 247 247
248 248 248
249 249 249
250 250 250
251 251 251
252 252 252
253 253 253
254 254 254
255 255 255'''
import numpy as np
tensor = np.array([[float(x)/255 for x in line.split()] for line in palette_str.split('\n')])
from matplotlib.colors import ListedColormap
palette = ListedColormap(tensor, 'davis')
================================================
FILE: utils/stat_manager.py
================================================
"""
Copyright (c) 2021 TU Darmstadt
Author: Nikita Araslanov <nikita.araslanov@tu-darmstadt.de>
License: Apache License 2.0
"""
class StatManager(object):
def __init__(self):
self.func_keys = {}
self.vals = {}
self.vals_count = {}
self.formats = {}
def items(self):
curr_vals = {}
for k in self.vals:
if self.has_vals(k):
curr_vals[k] = self.summarize_key(k)
return curr_vals.items()
def reset(self):
for k in self.vals:
self.vals[k] = 0.0
self.vals_count[k] = 0.0
def add_val(self, key, form="{:4.3f}"):
self.vals[key] = 0.0
self.vals_count[key] = 0.0
self.formats[key] = form
def get_val(self, key):
return self.vals[key], self.vals_count[key]
def add_compute(self, key, func, form="{:4.3f}"):
self.func_keys[key] = func
self.add_val(key)
self.formats[key] = form
def update_stats(self, key, val, count = 1):
if not key in self.vals:
self.add_val(key)
self.vals[key] += val
self.vals_count[key] += count
def compute_stats(self, a, b, size = 1):
for k, func in self.func_keys.items():  # dict.iteritems() exists only in Python 2
self.vals[k] += func(a, b)
self.vals_count[k] += size
def has_vals(self, k):
return k in self.vals_count and \
self.vals_count[k] > 0
def summarize_key(self, k):
if self.has_vals(k) and abs(self.vals[k]) > 0.:
return self.vals[k] / self.vals_count[k]
else:
return 0
def summarize(self, epoch = 0, verbose = True):
if verbose:
out = "\tEpoch[{:03d}]".format(epoch)
for k in self.vals:
if self.has_vals(k):
out += (" / {} " + self.formats[k]).format(k, self.summarize_key(k))
print(out)
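# Hedged usage sketch (not part of the original file): accumulating a running loss the
# way the training loop uses StatManager; the loss values here are made up.
if __name__ == "__main__":
    stat = StatManager()
    for loss in (0.8, 0.6, 0.4):
        stat.update_stats("loss/total", loss)
    for name, val in stat.items():
        print(name, val)               # -> loss/total ~0.6 (running mean)
    stat.summarize(epoch=1)            # -> Epoch[001] / loss/total 0.600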
================================================
FILE: utils/sys_tools.py
================================================
"""
Copyright (c) 2021 TU Darmstadt
Author: Nikita Araslanov <nikita.araslanov@tu-darmstadt.de>
License: Apache License 2.0
"""
import os
def check_dir(base_path, name):
"""Make sure the directory exists"""
# create the directory
fullpath = os.path.join(base_path, name)
if not os.path.exists(fullpath):
os.makedirs(fullpath)
return fullpath
================================================
FILE: utils/timer.py
================================================
"""
Copyright (c) 2021 TU Darmstadt
Author: Nikita Araslanov <nikita.araslanov@tu-darmstadt.de>
License: Apache License 2.0
"""
import time
class Timer:
def __init__(self, starting_msg = None):
self.start = time.time()
self.stage_start = self.start
if starting_msg is not None:
print(starting_msg, time.ctime(time.time()))
def update_progress(self, progress):
self.elapsed = time.time() - self.start
self.est_total = self.elapsed / progress
self.est_remaining = self.est_total - self.elapsed
self.est_finish = int(self.start + self.est_total)
def stage(self, msg=None):
t = self.get_stage_elapsed()
self.reset_stage()
if not msg is None:
print("{:4.3f} elapsed: {}".format(t, msg))
return t
def str_est_finish(self):
return str(time.ctime(self.est_finish))
def get_stage_elapsed(self):
return time.time() - self.stage_start
def reset_stage(self):
self.stage_start = time.time()
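# Hedged usage sketch (not part of the original file): timing one stage of work plus a
# rough overall progress estimate; the sleep stands in for real computation.
if __name__ == "__main__":
    timer = Timer("Starting toy run")
    time.sleep(0.2)                    # pretend to do some work
    timer.stage("first stage done")    # prints elapsed time and resets the stage clock
    timer.update_progress(0.5)         # claim the job is half-way done
    print("estimated finish:", timer.str_est_finish())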
SYMBOL INDEX (241 symbols across 27 files)
FILE: base_trainer.py
class BaseTrainer (line 25) | class BaseTrainer(object):
method __init__ (line 27) | def __init__(self, args, cfg):
method checkpoint_best (line 37) | def checkpoint_best(self, score, epoch, temp):
method get_optim (line 49) | def get_optim(params, cfg):
method set_lr (line 77) | def set_lr(optim, lr):
method _downsize (line 81) | def _downsize(self, x, mode="bilinear"):
method _visualise_seg (line 94) | def _visualise_seg(self, epoch, outs, writer, tag, S = 5):
method _visualise (line 128) | def _visualise(self, epoch, outs, T, writer, tag):
method save_vis_batch (line 233) | def save_vis_batch(self, key, batch):
method has_vis_batch (line 248) | def has_vis_batch(self, key):
method _mask_rgb (line 252) | def _mask_rgb(self, masks, image_norm=None, palette=None, alpha=0.3):
method _apply_cmap (line 267) | def _apply_cmap(self, mask_idx, palette=None, mask_conf=None, rand=True):
method _error_rgb (line 292) | def _error_rgb(self, error_mask, cmap = cm.get_cmap('jet'), image=None...
method _visualise_grid (line 305) | def _visualise_grid(self, writer, x_all, t, tag, T=1):
method visualise_results (line 319) | def visualise_results(self, epoch, writer, tag, step_func):
FILE: core/config.py
function assert_and_infer_cfg (line 114) | def assert_and_infer_cfg(make_immutable=True):
function merge_cfg_from_file (line 126) | def merge_cfg_from_file(cfg_filename):
function merge_cfg_from_cfg (line 135) | def merge_cfg_from_cfg(cfg_other):
function merge_cfg_from_list (line 140) | def merge_cfg_from_list(cfg_list):
function _merge_a_into_b (line 162) | def _merge_a_into_b(a, b, stack=None):
function _decode_cfg_value (line 190) | def _decode_cfg_value(v):
function _check_and_coerce_cfg_value_type (line 223) | def _check_and_coerce_cfg_value_type(value_a, value_b, key, full_key):
FILE: datasets/__init__.py
function get_sets (line 13) | def get_sets(task):
function get_dataloader (line 38) | def get_dataloader(args, cfg, split):
FILE: datasets/dataloader_base.py
class DLBase (line 10) | class DLBase(data.Dataset):
method __init__ (line 12) | def __init__(self, *args, **kwargs):
method _init_means (line 21) | def _init_means(self):
method _init_palette (line 25) | def _init_palette(self, num_classes):
method get_palette (line 28) | def get_palette(self):
method remove_labels (line 31) | def remove_labels(self, mask):
FILE: datasets/dataloader_infer.py
class DataSeg (line 18) | class DataSeg(DLBase):
method __init__ (line 20) | def __init__(self, cfg, split, ignore_labels=[], \
method __len__ (line 95) | def __len__(self):
method _mask2tensor (line 99) | def _mask2tensor(self, mask, num_classes=6):
method denorm (line 108) | def denorm(self, image):
method __getitem__ (line 126) | def __getitem__(self, index):
FILE: datasets/dataloader_seg.py
class DataSeg (line 18) | class DataSeg(DLBase):
method __init__ (line 20) | def __init__(self, cfg, split, ignore_labels=[], \
method __len__ (line 97) | def __len__(self):
method denorm (line 100) | def denorm(self, image):
method _mask2tensor (line 117) | def _mask2tensor(self, mask, num_classes=6):
method __getitem__ (line 126) | def __getitem__(self, index):
FILE: datasets/dataloader_video.py
class DataVideo (line 19) | class DataVideo(DLBase):
method __init__ (line 21) | def __init__(self, cfg, split, val=False):
method _init_augm (line 91) | def _init_augm(self, cfg):
method set_num_samples (line 136) | def set_num_samples(self, n):
method __len__ (line 140) | def __len__(self):
method denorm (line 143) | def denorm(self, image):
method _get_affine (line 160) | def _get_affine(self, params):
method _get_affine_inv (line 193) | def _get_affine_inv(self, affine, params):
method __getitem__ (line 209) | def __getitem__(self, index):
class DataVideoKinetics (line 274) | class DataVideoKinetics(DataVideo):
method __init__ (line 276) | def __init__(self, cfg, split):
method __len__ (line 308) | def __len__(self):
method __getitem__ (line 311) | def __getitem__(self, index):
FILE: datasets/daugm_video.py
class Compose (line 15) | class Compose:
method __init__ (line 17) | def __init__(self, segtransform):
method __call__ (line 20) | def __call__(self, args, *more_args):
class ToTensorMask (line 30) | class ToTensorMask:
method __toByteTensor (line 32) | def __toByteTensor(self, pic):
method __call__ (line 35) | def __call__(self, images, masks):
class CreateMask (line 44) | class CreateMask:
method __call__ (line 49) | def __call__(self, images):
class Normalize (line 57) | class Normalize:
method __init__ (line 60) | def __init__(self, mean, std=None):
method __call__ (line 70) | def __call__(self, images, masks):
class ApplyMask (line 83) | class ApplyMask:
method __init__ (line 85) | def __init__(self, ignore_label):
method __call__ (line 88) | def __call__(self, images, masks):
class GuidedRandHFlip (line 96) | class GuidedRandHFlip:
method __call__ (line 98) | def __call__(self, images, mask, affine=None):
class AffineIdentity (line 110) | class AffineIdentity(object):
method __call__ (line 112) | def __call__(self, images, masks, affine=None):
class MaskRandScaleCrop (line 119) | class MaskRandScaleCrop(object):
method __init__ (line 121) | def __init__(self, scale_from, scale_to):
method get_scale (line 126) | def get_scale(self):
method get_params (line 129) | def get_params(self, h, w, new_scale):
method __call__ (line 147) | def __call__(self, images, masks, affine=None):
class MaskScaleSmallest (line 198) | class MaskScaleSmallest(object):
method __init__ (line 200) | def __init__(self, smallest_range):
method __call__ (line 203) | def __call__(self, images, masks):
class MaskRandCrop (line 229) | class MaskRandCrop:
method __init__ (line 231) | def __init__(self, size, pad_if_needed=False):
method __pad (line 235) | def __pad(self, img, padding_mode='constant', fill=0):
method __call__ (line 249) | def __call__(self, images, masks):
class MaskCenterCrop (line 263) | class MaskCenterCrop:
method __init__ (line 265) | def __init__(self, size):
method __call__ (line 268) | def __call__(self, images, masks):
class MaskRandHFlip (line 276) | class MaskRandHFlip:
method __call__ (line 278) | def __call__(self, images, masks):
FILE: infer_vos.py
function mask2rgb (line 42) | def mask2rgb(mask, palette):
function mask_overlay (line 47) | def mask_overlay(mask, image, palette):
class ResultWriter (line 52) | class ResultWriter:
method __init__ (line 54) | def __init__(self, key, palette, out_path):
method save (line 60) | def save(self, frames, masks_pred, masks_conf, masks_gt, flags, fn, se...
function convert_dict (line 89) | def convert_dict(state_dict):
function mask2tensor (line 96) | def mask2tensor(mask, idx, num_classes=cfg.DATASET.NUM_CLASSES):
function configure_tracks (line 102) | def configure_tracks(masks_gt, tracks, num_objects):
function make_onehot (line 127) | def make_onehot(mask, HW):
function scale_smallest (line 138) | def scale_smallest(frame, a):
function valid_mask (line 144) | def valid_mask(mask):
function merge_mask_ids (line 152) | def merge_mask_ids(masks, key0):
function step_seg (line 163) | def step_seg(cfg, net, labelprop, frames, mask_init):
FILE: labelprop/common.py
class LabelPropVOS (line 12) | class LabelPropVOS(object):
method context_long (line 14) | def context_long(self):
method context_short (line 20) | def context_short(self, t):
method predict (line 30) | def predict(self, feats, masks, curr_feat):
class LabelPropVOS_CRW (line 42) | class LabelPropVOS_CRW(LabelPropVOS):
method __init__ (line 44) | def __init__(self, cfg):
method context_long (line 52) | def context_long(self, t0, t):
method context_short (line 55) | def context_short(self, t0, t):
method context_index (line 61) | def context_index(self, t0, t):
method predict (line 67) | def predict(self, feats, masks, curr_feat, ref_index=None, t=None):
FILE: labelprop/crw.py
class CRW (line 12) | class CRW(object):
method __init__ (line 15) | def __init__(self, cfg):
method _prep_context (line 36) | def _prep_context(self, feats, lbls, hw):
method forward (line 51) | def forward(self, feats, lbls):
function context_index_bank (line 134) | def context_index_bank(n_context, long_mem, N):
function batched_affinity (line 151) | def batched_affinity(query, keys, mask, temperature, topk, long_mem, dev...
function mem_efficient_batched_affinity (line 174) | def mem_efficient_batched_affinity(query, keys, mask, temperature, topk,...
class MaskedAttention (line 207) | class MaskedAttention(nn.Module):
method __init__ (line 212) | def __init__(self, radius, flat=True):
method mask (line 219) | def mask(self, H, W):
method index (line 224) | def index(self, H, W):
method make (line 229) | def make(self, H, W):
method flatten (line 244) | def flatten(self, D):
method make_index (line 247) | def make_index(self, H, W, pad=False):
method forward (line 255) | def forward(self, x):
FILE: models/__init__.py
function get_model (line 11) | def get_model(cfg, *args, **kwargs):
FILE: models/base.py
class BaseNet (line 12) | class BaseNet(nn.Module):
method __init__ (line 17) | def __init__(self):
method lr_mult (line 29) | def lr_mult(self):
method lr_mult_bias (line 34) | def lr_mult_bias(self):
method _is_learnable (line 39) | def _is_learnable(self, layer):
method _from_scratch (line 42) | def _from_scratch(self, net, ignore=[]):
method _freeze_bn (line 48) | def _freeze_bn(self, net, ignore=[]):
method _fix_bn (line 61) | def _fix_bn(self, layer):
method __set_grad_mode (line 69) | def __set_grad_mode(self, layer, mode, only_type=None):
method train (line 83) | def train(self, mode=True):
method parameter_groups (line 94) | def parameter_groups(self, base_lr, wd):
method _resize_as (line 130) | def _resize_as(x, y):
FILE: models/framework.py
class Framework (line 13) | class Framework(BaseNet):
method __init__ (line 15) | def __init__(self, cfg, net):
method parameter_groups (line 22) | def parameter_groups(self, base_lr, wd):
method _align (line 25) | def _align(self, x, t):
method _key_val (line 29) | def _key_val(self, ctr, q):
method _sample_index (line 44) | def _sample_index(self, x, T, N):
method _sample_from (line 73) | def _sample_from(self, x, index, T, N):
method _mark_from (line 102) | def _mark_from(self, x, index, T, N, fill_value=0):
method _cluster_grid (line 136) | def _cluster_grid(self, k1, k2, aff1, aff2, T, index=None):
method _aff_sample (line 188) | def _aff_sample(self, k1, k2, T):
method _pseudo_mask (line 220) | def _pseudo_mask(self, logits, T):
method _ref_loss (line 239) | def _ref_loss(self, x, y, N = 4):
method _ce_loss (line 250) | def _ce_loss(self, x, pseudo_map, T, eps=1e-5):
method _forward_reg (line 259) | def _forward_reg(self, frames2, norm):
method fetch_first (line 278) | def fetch_first(self, x1, x2, T):
method forward (line 290) | def forward(self, frames, frames2=None, mask=None, T=None, affine=None...
FILE: models/net.py
class MLP (line 12) | class MLP(nn.Sequential):
method __init__ (line 14) | def __init__(self, n_in, n_out):
class Net (line 22) | class Net(BaseNet):
method __init__ (line 24) | def __init__(self, cfg, backbone):
method lr_mult (line 31) | def lr_mult(self):
method lr_mult_bias (line 36) | def lr_mult_bias(self):
method forward (line 41) | def forward(self, frames, norm=True):
FILE: models/resnet18.py
class ResNet (line 15) | class ResNet(torch_resnet.ResNet):
method __init__ (line 17) | def __init__(self, *args, **kwargs):
method filter_layers (line 20) | def filter_layers(self, x):
method remove_layers (line 23) | def remove_layers(self, remove_layers=[]):
method modify (line 29) | def modify(self):
method forward (line 43) | def forward(self, x):
function _resnet (line 56) | def _resnet(arch, block, layers, pretrained, **kwargs):
function resnet18 (line 60) | def resnet18(pretrained='', remove_layers=[], train=True, **kwargs):
FILE: opts.py
function add_global_arguments (line 14) | def add_global_arguments(parser):
function maybe_create_dir (line 49) | def maybe_create_dir(path):
function check_global_arguments (line 53) | def check_global_arguments(args):
function get_arguments (line 67) | def get_arguments(args_in):
FILE: train.py
class Trainer (line 38) | class Trainer(BaseTrainer):
method __init__ (line 40) | def __init__(self, args, cfg):
method step_seg (line 105) | def step_seg(self, epoch, batch_src, key, temp=None, train=False, visu...
method step (line 160) | def step(self, epoch, batch_in, train=False, visualise=False, save_bat...
method train_epoch (line 205) | def train_epoch(self, epoch):
method validation (line 253) | def validation(self, epoch, writer, loader, tag=None, max_iter=None):
method validation_seg (line 295) | def validation_seg(self, epoch, writer, loader, key="all", temp=None, ...
function train (line 378) | def train(args, cfg):
function main (line 430) | def main():
FILE: utils/checkpoints.py
class Checkpoint (line 12) | class Checkpoint(object):
method __init__ (line 14) | def __init__(self, path, max_n=3):
method create_model (line 20) | def create_model(self, model, opt):
method limit (line 25) | def limit(self):
method __len__ (line 28) | def __len__(self):
method _get_full_path (line 31) | def _get_full_path(self, suffix):
method clean (line 35) | def clean(self):
method _rm (line 41) | def _rm(self, suffix):
method _filename (line 46) | def _filename(self, suffix):
method load (line 49) | def load(self, path, location):
method checkpoint (line 61) | def checkpoint(self, score, epoch, t):
FILE: utils/collections.py
class AttrDict (line 24) | class AttrDict(dict):
method __init__ (line 28) | def __init__(self, *args, **kwargs):
method __getattr__ (line 32) | def __getattr__(self, name):
method __setattr__ (line 40) | def __setattr__(self, name, value):
method immutable (line 52) | def immutable(self, is_immutable):
method is_immutable (line 65) | def is_immutable(self):
FILE: utils/davis2017.py
function _evaluate_semisupervised (line 12) | def _evaluate_semisupervised(all_gt_masks, all_res_masks, all_void_masks...
function evaluate_semi (line 34) | def evaluate_semi(all_gt_masks, all_res_masks, metric=('J', 'F'), debug=...
FILE: utils/davis2017_metrics.py
function db_eval_iou (line 17) | def db_eval_iou(annotation, segmentation, void_pixels=None):
function db_eval_boundary (line 51) | def db_eval_boundary(annotation, segmentation, void_pixels=None, bound_t...
function f_measure (line 68) | def f_measure(foreground_mask, gt_mask, void_pixels=None, bound_th=0.008):
function _seg2bmap (line 129) | def _seg2bmap(seg, width=None, height=None):
FILE: utils/davis2017_utils.py
function _pascal_color_map (line 14) | def _pascal_color_map(N=256, normalized=False):
function overlay_semantic_mask (line 41) | def overlay_semantic_mask(im, ann, alpha=0.5, colors=None, contour_thick...
function generate_obj_proposals (line 67) | def generate_obj_proposals(davis_root, subset, num_proposals, save_path):
function generate_random_permutation_gt_obj_proposals (line 98) | def generate_random_permutation_gt_obj_proposals(davis_root, subset, sav...
function color_map (line 111) | def color_map(N=256, normalized=False):
function save_mask (line 132) | def save_mask(mask, img_path):
function db_statistics (line 140) | def db_statistics(per_frame_values):
function list_files (line 168) | def list_files(dir, extension=".png"):
function force_symlink (line 172) | def force_symlink(file1, file2):
FILE: utils/palette.py
function colormap (line 11) | def colormap(N=256):
function apply_cmap (line 30) | def apply_cmap(masks_pred, cmap):
function create_palette (line 39) | def create_palette(colormap, num):
function custom_palette (line 51) | def custom_palette(nclasses, cname="rainbow"):
FILE: utils/stat_manager.py
class StatManager (line 8) | class StatManager(object):
method __init__ (line 10) | def __init__(self):
method items (line 16) | def items(self):
method reset (line 23) | def reset(self):
method add_val (line 28) | def add_val(self, key, form="{:4.3f}"):
method get_val (line 33) | def get_val(self, key):
method add_compute (line 36) | def add_compute(self, key, func, form="{:4.3f}"):
method update_stats (line 41) | def update_stats(self, key, val, count = 1):
method compute_stats (line 48) | def compute_stats(self, a, b, size = 1):
method has_vals (line 54) | def has_vals(self, k):
method summarize_key (line 58) | def summarize_key(self, k):
method summarize (line 64) | def summarize(self, epoch = 0, verbose = True):
FILE: utils/sys_tools.py
function check_dir (line 9) | def check_dir(base_path, name):
FILE: utils/timer.py
class Timer (line 9) | class Timer:
method __init__ (line 10) | def __init__(self, starting_msg = None):
method update_progress (line 18) | def update_progress(self, progress):
method stage (line 24) | def stage(self, msg=None):
method str_est_finish (line 31) | def str_est_finish(self):
method get_stage_elapsed (line 34) | def get_stage_elapsed(self):
method reset_stage (line 37) | def reset_stage(self):