Repository: visinf/dense-ulearn-vos Branch: main Commit: 88e5e3518cb0 Files: 40 Total size: 160.6 KB Directory structure: gitextract_9fs7eg_t/ ├── .gitignore ├── LICENSE ├── README.md ├── base_trainer.py ├── configs/ │ ├── kinetics.yaml │ ├── oxuva.yaml │ ├── tracknet.yaml │ └── ytvos.yaml ├── core/ │ ├── __init__.py │ └── config.py ├── datasets/ │ ├── __init__.py │ ├── dataloader_base.py │ ├── dataloader_infer.py │ ├── dataloader_seg.py │ ├── dataloader_video.py │ └── daugm_video.py ├── infer_vos.py ├── labelprop/ │ ├── common.py │ └── crw.py ├── launch/ │ ├── infer_vos.sh │ ├── train.sh │ └── utils.bash ├── models/ │ ├── __init__.py │ ├── base.py │ ├── framework.py │ ├── net.py │ └── resnet18.py ├── opts.py ├── requirements.txt ├── train.py └── utils/ ├── checkpoints.py ├── collections.py ├── davis2017.py ├── davis2017_metrics.py ├── davis2017_utils.py ├── palette.py ├── palette_davis.py ├── stat_manager.py ├── sys_tools.py └── timer.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ __pycache__/ *.pyc *.sw* data/ libs/ models/pretrained logs snapshots ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright 2021 TU Darmstadt Author: Nikita Araslanov Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================ # Dense Unsupervised Learning for Video Segmentation [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) [![Framework](https://img.shields.io/badge/PyTorch-%23EE4C2C.svg?&logo=PyTorch&logoColor=white)](https://pytorch.org/) This repository contains the official implementation of our paper: **Dense Unsupervised Learning for Video Segmentation**
[Nikita Araslanov](https://arnike.github.io), [Simone Schaub-Meyer](https://schaubsi.github.io), and [Stefan Roth](https://www.visinf.tu-darmstadt.de/visinf/team_members/sroth/sroth.en.jsp)
To appear at NeurIPS*2021. [[paper](https://openreview.net/pdf?id=i8kfkuiCJCI)] [[supp](https://openreview.net/attachment?id=i8kfkuiCJCI&name=supplementary_material)] [[talk](https://youtu.be/tSBWZ6nYld0)] [[example results](https://youtu.be/BqVtZJSLOzg)] [[arXiv](https://arxiv.org/abs/2111.06265)]

*(Figure: teaser animation of example segmentation results; see the video links above.)*

We efficiently learn spatio-temporal correspondences without any supervision, and achieve state-of-the-art accuracy in video object segmentation.

Contact: Nikita Araslanov *fname.lname* (at) visinf.tu-darmstadt.de

---

## Installation

**Requirements.** To reproduce our results, we recommend Python >=3.6, PyTorch >=1.4, and CUDA >=10.0. At least one Titan X GPU (12GB) or equivalent is required. The code was primarily developed under PyTorch 1.8 on a single A100 GPU.

The following steps will set up a local copy of the repository.

1. Create the conda environment:
```
conda create --name dense-ulearn-vos
source activate dense-ulearn-vos
```

2. Install PyTorch >=1.4 (see the [PyTorch instructions](https://pytorch.org/get-started/locally/)). For example, on Linux run:
```
conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
```

3. Install the dependencies:
```
pip install -r requirements.txt
```

4. Download the data:

| Dataset | Website | Target directory with video sequences |
|:-:|:-:|:--|
| YouTube-VOS | [Link](https://competitions.codalab.org/competitions/19544#participate-get-data) | `data/ytvos/train/JPEGImages/` |
| OxUvA | [Link](https://oxuva.github.io/long-term-tracking-benchmark/) | `data/OxUvA/images/dev/` |
| TrackingNet | [Link](https://github.com/SilvioGiancola/TrackingNet-devkit) | `data/tracking/train/jpegs/` |
| Kinetics-400 | [Link](https://deepmind.com/research/open-source/kinetics) | `data/kinetics400/video_jpeg/train/` |

The last column specifies the path (relative to the project root) to the subdirectories containing the video frames. You can use a different path structure; in that case, adjust the paths in `data/filelists/` for every dataset accordingly.

5. Download the filelists:
```
cd data/filelists
bash download.sh
```
This will download the lists of training and validation paths for all datasets.

## Training

The following bash script will train a ResNet-18 model from scratch on one of the four supported datasets (see above):
```
bash ./launch/train.sh [ytvos|oxuva|track|kinetics]
```

We also provide our final models for download.

| Dataset | Mean J&F (DAVIS-2017) | Link | MD5 |
|---|:-:|:--:|---|
| OxUvA | 65.3 | [oxuva_e430_res4.pth (132M)](https://download.visinf.tu-darmstadt.de/data/2021-neurips-araslanov-vos/snapshots/oxuva_e430_res4.pth) | `af541[...]d09b3` |
| YouTube-VOS | 69.3 | [ytvos_e060_res4.pth (132M)](https://download.visinf.tu-darmstadt.de/data/2021-neurips-araslanov-vos/snapshots/ytvos_e060_res4.pth) | `c3ae3[...]55faf` |
| TrackingNet | 69.4 | [trackingnet_e088_res4.pth (88M)](https://download.visinf.tu-darmstadt.de/data/2021-neurips-araslanov-vos/snapshots/trackingnet_e088_res4.pth) | `3e7e9[...]95fa9` |
| Kinetics-400 | 68.7 | [kinetics_e026_res4.pth (88M)](https://download.visinf.tu-darmstadt.de/data/2021-neurips-araslanov-vos/snapshots/kinetics_e026_res4.pth) | `086db[...]a7d98` |

## Inference and evaluation

### Inference

To run inference, use `launch/infer_vos.sh`:
```
bash ./launch/infer_vos.sh [davis|ytvos]
```
The first argument selects the validation dataset (`davis` for DAVIS-2017; `ytvos` for YouTube-VOS). The bash variables declared in the script set up the paths for reading the data and the pre-trained models, as well as the output directory:

* `EXP`, `RUN_ID` and `SNAPSHOT` determine the pre-trained model to load.
* `VER` specifies a suffix for the output directory (in case you would like to experiment with different configurations for label propagation).

Please refer to `launch/infer_vos.sh` for their usage; an example is sketched below.
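For illustration only, the variable declarations inside `launch/infer_vos.sh` might be adjusted along the lines of the following sketch. All values are hypothetical placeholders; the actual defaults (including the `KEY` variable that selects the output layer) are defined in the script itself.

```bash
# Hypothetical values -- adapt to your local setup; these variables are
# declared inside launch/infer_vos.sh, not passed on the command line.
EXP=ytvos                  # together with RUN_ID and SNAPSHOT, selects the model to load
RUN_ID=v01
SNAPSHOT=ytvos_e060_res4   # e.g. a checkpoint from the table above
VER=baseline               # suffix for the output directory
KEY=res4                   # codename of the output CNN layer used in the evaluation
```

After adjusting the variables, start the inference as shown above, e.g. `bash ./launch/infer_vos.sh davis`.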
The inference script will create two directories with the result: `[res3|res4|key]_vos` and `[res3|res4|key]_vis`, where the prefix corresponds to the codename of the output CNN layer used in the evaluation (selected in `infer_vos.sh` using `KEY` variable). The `vos`-directory contains the segmentation result ready for evaluation; the `vis`-directory produces the results for visualisation purposes. You can optionally disable generating the visualisation by setting `VERBOSE=False` in `infer_vos.py`. ### Evaluation: DAVIS-2017 Please use the official [evaluation package](https://github.com/davisvideochallenge/davis2017-evaluation). Install the repository, then simply run: ``` python evaluation_method.py --task semi-supervised --davis_path data/davis2017 --results_path ``` ### Evaluation: YouTube-VOS 2018 Please use the official [CodaLab evaluation server](https://competitions.codalab.org/competitions/19544#participate-submit_results). To create the submission, rename the `vos`-directory to `Annotations` and compress it to `Annotations.zip` for uploading. ## Acknowledgements We thank PyTorch contributors and [Allan Jabri](https://ajabri.github.io) for releasing [their implementation](https://github.com/ajabri/videowalk) of the label propagation. ## Citation We hope you find our work useful. If you would like to acknowledge it in your project, please use the following citation: ``` @inproceedings{Araslanov:2021:DUL, author = {Araslanov, Nikita and Simone Schaub-Mayer and Roth, Stefan}, title = {Dense Unsupervised Learning for Video Segmentation}, booktitle = {Advances in Neural Information Processing Systems (NeurIPS)}, volume = {34}, year = {2021} } ``` ================================================ FILE: base_trainer.py ================================================ """ Copyright (c) 2021 TU Darmstadt Author: Nikita Araslanov License: Apache License 2.0 """ import os import torch import math import numpy as np import torch.nn.functional as F import torchvision.utils as vutils from torch.utils.tensorboard import SummaryWriter from torch.optim.optimizer import Optimizer from utils.checkpoints import Checkpoint from utils.palette_davis import palette as palette_davis from PIL import Image from matplotlib import cm class BaseTrainer(object): def __init__(self, args, cfg): self.args = args self.cfg = cfg self.start_epoch = 0 self.best_score = -1e16 self.checkpoint = Checkpoint(args.snapshot_dir, max_n = 3) logdir = os.path.join(args.logdir, 'train') self.writer = SummaryWriter(logdir) def checkpoint_best(self, score, epoch, temp): if score > self.best_score: print(">>> Saving checkpoint with score {:3.2e}, epoch {}".format(score, epoch)) self.best_score = score self.checkpoint.checkpoint(score, epoch, temp) return True return False @staticmethod def get_optim(params, cfg): if not hasattr(torch.optim, cfg.OPT): print("Optimiser {} not supported".format(cfg.OPT)) raise NotImplementedError optim = getattr(torch.optim, cfg.OPT) if cfg.OPT == 'Adam': print("Using Adam >>> learning rate = {:4.3e}, momentum = {:4.3e}, weight decay = {:4.3e}".format(cfg.LR, cfg.MOMENTUM, cfg.WEIGHT_DECAY)) upd = torch.optim.Adam(params, lr=cfg.LR, \ betas=(cfg.BETA1, 0.999), \ weight_decay=cfg.WEIGHT_DECAY) elif cfg.OPT == 'SGD': print("Using SGD >>> learning rate = {:4.3e}, momentum = {:4.3e}, weight decay = {:4.3e}".format(cfg.LR, cfg.MOMENTUM, cfg.WEIGHT_DECAY)) upd = torch.optim.SGD(params, lr=cfg.LR, \ momentum=cfg.MOMENTUM, \ nesterov=cfg.OPT_NESTEROV, \ weight_decay=cfg.WEIGHT_DECAY) else: upd = optim(params, 
lr=cfg.LR) upd.zero_grad() return upd @staticmethod def set_lr(optim, lr): for param_group in optim.param_groups: param_group['lr'] = lr def _downsize(self, x, mode="bilinear"): x = x.float() if x.dim() == 3: x = x.unsqueeze(1) scale = min(*self.cfg.TB.IM_SIZE) / min(x.shape[-1], x.shape[-2]) if mode == "nearest": x = F.interpolate(x, scale_factor=scale, mode="nearest") else: x = F.interpolate(x, scale_factor=scale, mode=mode, align_corners=True) return x.squeeze(1) def _visualise_seg(self, epoch, outs, writer, tag, S = 5): def with_frame(image, mask, alpha=0.3): return alpha * image + (1 - alpha) * mask frames = outs["frames"][::S] frames_norm = self.denorm(frames.cpu().clone()) frames_down = self._downsize(frames_norm) T,C,h,w = frames_down.shape visuals = [] visuals.append(frames_down) if "masks_gt" in outs: mask_rgb_gt = self._apply_cmap(outs["masks_gt"][::S].cpu(), palette_davis, rand=False) mask_rgb_gt = self._downsize(mask_rgb_gt) mask_rgb_gt = with_frame(frames_down, mask_rgb_gt) visuals.append(mask_rgb_gt) mask_rgb_idx = self._apply_cmap(outs["masks_pred_idx"][::S].cpu(), palette_davis, rand=False) mask_rgb_idx = self._downsize(mask_rgb_idx) mask_rgb_idx = with_frame(frames_down, mask_rgb_idx) visuals.append(mask_rgb_idx) #if "masks_pred_conf" in outs: conf = self._downsize(outs["masks_pred_conf"][::S].cpu()) conf_rgb = self._error_rgb(conf, cm.get_cmap("plasma"), frames_down, 0.3) visuals.append(conf_rgb) visuals = [x.float() for x in visuals] visuals = torch.cat(visuals, -1) self._visualise_grid(writer, visuals, epoch, tag) def _visualise(self, epoch, outs, T, writer, tag): visuals = [] def overlay(mask, image, alpha=0.3): return alpha * image + (1 - alpha) * mask frames_orig = outs["frames_orig"] frames_orig = self.denorm(frames_orig.cpu().clone()) frames_orig = self._downsize(frames_orig) visuals.append(frames_orig) frames = outs["frames"] frames_norm = self.denorm(frames.cpu().clone()) frames_down = self._downsize(frames_norm) if "grid_mask" in outs: val_mask = outs["grid_mask"] val_mask = self._downsize(val_mask) val_mask = val_mask.unsqueeze(1).expand(-1,3,-1,-1).cpu() val_mask = overlay(val_mask, frames_orig) visuals.append(val_mask) if "map_target" in outs: val = outs["map_target"] val = self._apply_cmap(val) val = self._downsize(val, "nearest") val = overlay(val, frames_down) visuals.append(val) if "map_soft" in outs: val = outs["map_soft"] val = self._mask_rgb(val) val = self._downsize(val) visuals.append(val) visuals.append(frames_down) frames2 = outs["frames2"] frames2_norm = self.denorm(frames2.cpu().clone()) frames2_down = self._downsize(frames2_norm) visuals.append(frames2_down) if "map" in outs: val = outs["map"] val = self._apply_cmap(val) val = self._downsize(val, "nearest").cpu() val = overlay(val, frames2_down) visuals.append(val) if "map_target_soft" in outs: val = outs["map_target_soft"] val = self._mask_rgb(val) val = self._downsize(val) visuals.append(val) # embedding error mask if "error_map" in outs: err_mask = outs["error_map"] err_mask = (err_mask - err_mask.min()) / (err_mask.max() - err_mask.min() + 1e-8) err_mask_rgb = self._error_rgb(err_mask, cmap=cm.get_cmap("plasma"), alpha=0.5) err_mask_rgb = self._downsize(err_mask_rgb) visuals.append(err_mask_rgb) if "aff_mask1" in outs: aff_mask = outs["aff_mask1"].unsqueeze(1).expand(-1,3,-1,-1).cpu() aff_mask = self._downsize(aff_mask) aff_frames = frames_orig.clone() aff_frames[::T] = overlay(aff_mask, aff_frames[::T], 0.5) visuals.append(aff_frames) aff_mask1 = self._error_rgb(outs["aff11"], 
cm.get_cmap("inferno")) aff_mask1 = self._downsize(aff_mask1) aff_mask1 = overlay(aff_mask1, frames_orig, 0.3) visuals.append(aff_mask1) aff_mask2 = self._error_rgb(outs["aff12"], cm.get_cmap("inferno")) aff_mask2 = self._downsize(aff_mask2) aff_mask2 = overlay(aff_mask2, frames2_down, 0.3) visuals.append(aff_mask2) if "aff_mask2" in outs: aff_mask = outs["aff_mask2"].unsqueeze(1).expand(-1,3,-1,-1).cpu() aff_mask = self._downsize(aff_mask) aff_frames = frames_down.clone() aff_frames[::T] = overlay(aff_mask, aff_frames[::T], 0.5) visuals.append(aff_frames) aff_mask1 = self._error_rgb(outs["aff21"], cm.get_cmap("inferno")) aff_mask1 = self._downsize(aff_mask1) aff_mask1 = overlay(aff_mask1, frames_orig, 0.3) visuals.append(aff_mask1) aff_mask2 = self._error_rgb(outs["aff22"], cm.get_cmap("inferno")) aff_mask2 = self._downsize(aff_mask2) aff_mask2 = overlay(aff_mask2, frames2_down, 0.3) visuals.append(aff_mask2) visuals = [x.cpu().float() for x in visuals] visuals = torch.cat(visuals, -1) self._visualise_grid(writer, visuals, epoch, tag, 4 * T) def save_vis_batch(self, key, batch): if self.vis_batch is None: self.vis_batch = {} if key in self.vis_batch: return batch_items = [] for el in batch: el = el.clone().cpu() if torch.is_tensor(el) else el batch_items.append(el) self.vis_batch[key] = batch_items def has_vis_batch(self, key): return (not self.vis_batch is None and \ key in self.vis_batch) def _mask_rgb(self, masks, image_norm=None, palette=None, alpha=0.3): if palette is None: palette = self.loader.dataset.palette # visualising masks masks_conf, masks_idx = torch.max(masks, 1) masks_conf = masks_conf - F.relu(masks_conf - 1, 0) masks_idx_rgb = self._apply_cmap(masks_idx.cpu(), palette, mask_conf=masks_conf.cpu()) if not image_norm is None: return alpha * image_norm + (1 - alpha) * masks_idx_rgb return masks_idx_rgb def _apply_cmap(self, mask_idx, palette=None, mask_conf=None, rand=True): if palette is None: palette = self.loader.dataset.palette ignore_mask = (mask_idx == -1).cpu() # cycle if rand: memsize = self.cfg.TRAIN.BATCH_SIZE * self.cfg.MODEL.GRID_SIZE**2 mask_idx = ((mask_idx + 1) * 123) % memsize # convert mask to RGB mask = mask_idx.cpu().numpy().astype(np.uint32) mask_rgb = palette(mask) mask_rgb = torch.from_numpy(mask_rgb[:,:,:,:3]) mask_rgb[ignore_mask] *= 0 mask_rgb = mask_rgb.permute(0,3,1,2) if not mask_conf is None: # entropy mask_rgb *= mask_conf[:, None, :, :] return mask_rgb def _error_rgb(self, error_mask, cmap = cm.get_cmap('jet'), image=None, alpha=0.3): error_np = error_mask.cpu().numpy() # remove alpha channel error_rgb = cmap(error_np)[:, :, :, :3] error_rgb = np.transpose(error_rgb, (0,3,1,2)) error_rgb = torch.from_numpy(error_rgb) if not image is None: return alpha * image + (1 - alpha) * error_rgb return error_rgb def _visualise_grid(self, writer, x_all, t, tag, T=1): # adding the labels to images bs, ch, h, w = x_all.size() x_all_new = torch.zeros(T, ch, h, w) for b in range(bs): x_all_new[b % T] = x_all[b] if (b + 1) % T == 0: summary_grid = vutils.make_grid(x_all_new, nrow=1, padding=8, pad_value=0.9).numpy() writer.add_image(tag + "_{:02d}".format(b // T), summary_grid, t) x_all_new.zero_() def visualise_results(self, epoch, writer, tag, step_func): # visualising self.net.eval() with torch.no_grad(): step_func(epoch, self.vis_batch[tag], \ train=False, visualise=True, \ writer=writer, tag=tag) ================================================ FILE: configs/kinetics.yaml ================================================ DATASET: ROOT: "data" 
SMALLEST_RANGE: [256,256] RND_CROP: False RND_ZOOM: True RND_ZOOM_RANGE: [.5, 1.] GUIDED_HFLIP: True VIDEO_LEN: 5 FRAME_GAP: 10 MP4: True TRAIN: BATCH_SIZE: 16 NUM_EPOCHS: 80 TASK: "kinetics400" MODEL: LR: 0.001 CE_REF: 0.1 OPT: "SGD" LR_SCHEDULER: "step" LR_GAMMA: 0.5 LR_STEP: 20 WEIGHT_DECAY: 0.0005 LOG: ITER_TRAIN: 8 ITER_VAL: 2 TEST: KNN: 10 CXT_SIZE: 20 RADIUS: 12 TEMP: 0.05 TB: IM_SIZE: [196, 196] ================================================ FILE: configs/oxuva.yaml ================================================ DATASET: ROOT: "data" SMALLEST_RANGE: [256, 320] RND_CROP: True RND_ZOOM: True RND_ZOOM_RANGE: [.5, 1.] GUIDED_HFLIP: True VIDEO_LEN: 5 FRAME_GAP: 10 TRAIN: BATCH_SIZE: 16 NUM_EPOCHS: 500 TASK: "OxUvA" MODEL: LR: 0.0001 OPT: "Adam" LR_SCHEDULER: "step" LR_GAMMA: 0.5 LR_STEP: 100 WEIGHT_DECAY: 0.0005 LOG: ITER_TRAIN: 20 ITER_VAL: 10 TEST: KNN: 10 CXT_SIZE: 20 RADIUS: 12 TEMP: 0.05 TB: IM_SIZE: [196, 196] ================================================ FILE: configs/tracknet.yaml ================================================ DATASET: ROOT: "data" SMALLEST_RANGE: [256,256] RND_CROP: False RND_ZOOM: True RND_ZOOM_RANGE: [.5, 1.] GUIDED_HFLIP: True VIDEO_LEN: 5 FRAME_GAP: 10 TRAIN: BATCH_SIZE: 16 NUM_EPOCHS: 250 TASK: "TrackingNet" BLOCK_BN: False STOP_GRAD: True MODEL: LR: 0.001 CE_REF: 0.1 OPT: "SGD" LR_SCHEDULER: "step" LR_GAMMA: 0.9 LR_STEP: 20 WEIGHT_DECAY: 0.0005 LOG: ITER_TRAIN: 10 ITER_VAL: 2 TEST: KNN: 10 CXT_SIZE: 20 RADIUS: 12 TEMP: 0.05 TB: IM_SIZE: [196, 196] ================================================ FILE: configs/ytvos.yaml ================================================ DATASET: ROOT: "data" SMALLEST_RANGE: [256, 320] RND_CROP: True RND_ZOOM: True RND_ZOOM_RANGE: [.5, 1.] GUIDED_HFLIP: True VIDEO_LEN: 5 FRAME_GAP: 2 TRAIN: BATCH_SIZE: 16 NUM_EPOCHS: 500 TASK: "YTVOS" MODEL: LR: 0.0001 OPT: "Adam" LR_SCHEDULER: "step" LR_GAMMA: 0.5 LR_STEP: 100 WEIGHT_DECAY: 0.0005 LOG: ITER_TRAIN: 20 ITER_VAL: 10 TEST: KNN: 10 CXT_SIZE: 20 RADIUS: 12 TEMP: 0.05 TB: IM_SIZE: [196, 196] ================================================ FILE: core/__init__.py ================================================ ================================================ FILE: core/config.py ================================================ from __future__ import absolute_import from __future__ import division from __future__ import print_function from __future__ import unicode_literals import yaml import six import os import os.path as osp import copy from ast import literal_eval import numpy as np from packaging import version from utils.collections import AttrDict __C = AttrDict() # Consumers can get config by: # from fast_rcnn_config import cfg cfg = __C __C.NUM_GPUS = 1 # Random note: avoid using '.ON' as a config key since yaml converts it to True; # prefer 'ENABLED' instead # ---------------------------------------------------------------------------- # # Training schedule options # ---------------------------------------------------------------------------- # __C.TRAIN = AttrDict() __C.TRAIN.BATCH_SIZE = 20 __C.TRAIN.NUM_EPOCHS = 15 __C.TRAIN.NUM_WORKERS = 4 __C.TRAIN.TASK = "YTVOS" __C.TRAIN.BLOCK_BN = False __C.TRAIN.STOP_GRAD = True # ---------------------------------------------------------------------------- # # Dataset options (+ augmentation options) # ---------------------------------------------------------------------------- # __C.DATASET = AttrDict() __C.DATASET.CROP_SIZE = [256,256] __C.DATASET.SMALLEST_RANGE = [256,320] __C.DATASET.RND_CROP = False 
__C.DATASET.RND_HFLIP = True __C.DATASET.MP4 = False __C.DATASET.RND_ZOOM = False __C.DATASET.RND_ZOOM_RANGE = [.5,1.] __C.DATASET.GUIDED_HFLIP = True __C.DATASET.ROOT = "data" __C.DATASET.VIDEO_LEN = 5 __C.DATASET.FRAME_GAP = 5 __C.DATASET.VAL_FRAME_GAP = 2 # size of the augmentation # (how many augmented copies to create for 1 image) __C.DATASET.NUM_CLASSES = 7 # inference-time parameters __C.TEST = AttrDict() __C.TEST.RADIUS = 12 __C.TEST.TEMP = 0.05 __C.TEST.KNN = 10 __C.TEST.CXT_SIZE = 20 __C.TEST.INPUT_SIZE = -1 __C.TEST.KEY = "res4" # ---------------------------------------------------------------------------- # # Model options # ---------------------------------------------------------------------------- # __C.MODEL = AttrDict() __C.MODEL.ARCH = 'resnet18' __C.MODEL.LR_SCHEDULER = 'poly' __C.MODEL.LR_STEP = 5 __C.MODEL.LR_GAMMA = 0.1 # divide by 10 every LR_STEP epochs __C.MODEL.LR_POWER = 1.0 __C.MODEL.LR_SCHED_USE_EPOCH = True __C.MODEL.OPT = 'Adam' __C.MODEL.OPT_NESTEROV = False __C.MODEL.LR = 3e-4 __C.MODEL.BETA1 = 0.5 __C.MODEL.MOMENTUM = 0.9 __C.MODEL.WEIGHT_DECAY = 1e-5 __C.MODEL.REMOVE_LAYERS = [] __C.MODEL.FEATURE_DIM = 128 __C.MODEL.GRID_SIZE = 8 __C.MODEL.GRID_SIZE_REF = 4 __C.MODEL.CE_REF = 0.1 # ---------------------------------------------------------------------------- # # Options for refinement # ---------------------------------------------------------------------------- # __C.LOG = AttrDict() __C.LOG.ITER_VAL = 1 __C.LOG.ITER_TRAIN = 10 # ---------------------------------------------------------------------------- # # Model options # ---------------------------------------------------------------------------- # __C.TB = AttrDict() __C.TB.IM_SIZE = (196, 196) # image resolution # [Infered value] __C.CUDA = False __C.DEBUG = False # [Infered value] __C.PYTORCH_VERSION_LESS_THAN_040 = False def assert_and_infer_cfg(make_immutable=True): """Call this function in your script after you have finished setting all cfg values that are necessary (e.g., merging a config from a file, merging command line config options, etc.). By default, this function will also mark the global cfg as immutable to prevent changing the global cfg settings during script execution (which can lead to hard to debug errors or code that's harder to understand than is necessary). """ if make_immutable: cfg.immutable(True) def merge_cfg_from_file(cfg_filename): """Load a yaml config file and merge it into the global config.""" with open(cfg_filename, 'r') as f: yaml_cfg = AttrDict(yaml.load(f, Loader=yaml.FullLoader)) _merge_a_into_b(yaml_cfg, __C) cfg_from_file = merge_cfg_from_file def merge_cfg_from_cfg(cfg_other): """Merge `cfg_other` into the global config.""" _merge_a_into_b(cfg_other, __C) def merge_cfg_from_list(cfg_list): """Merge config keys, values in a list (e.g., from command line) into the global config. For example, `cfg_list = ['TEST.NMS', 0.5]`. """ assert len(cfg_list) % 2 == 0 for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]): key_list = full_key.split('.') d = __C for subkey in key_list[:-1]: assert subkey in d, 'Non-existent key: {}'.format(full_key) d = d[subkey] subkey = key_list[-1] assert subkey in d, 'Non-existent key: {}'.format(full_key) value = _decode_cfg_value(v) value = _check_and_coerce_cfg_value_type( value, d[subkey], subkey, full_key ) d[subkey] = value cfg_from_list = merge_cfg_from_list def _merge_a_into_b(a, b, stack=None): """Merge config dictionary a into config dictionary b, clobbering the options in b whenever they are also specified in a. 
""" assert isinstance(a, AttrDict), 'Argument `a` must be an AttrDict' assert isinstance(b, AttrDict), 'Argument `b` must be an AttrDict' for k, v_ in a.items(): full_key = '.'.join(stack) + '.' + k if stack is not None else k # a must specify keys that are in b if k not in b: raise KeyError('Non-existent config key: {}'.format(full_key)) v = copy.deepcopy(v_) v = _decode_cfg_value(v) v = _check_and_coerce_cfg_value_type(v, b[k], k, full_key) # Recursively merge dicts if isinstance(v, AttrDict): try: stack_push = [k] if stack is None else stack + [k] _merge_a_into_b(v, b[k], stack=stack_push) except BaseException: raise else: b[k] = v def _decode_cfg_value(v): """Decodes a raw config value (e.g., from a yaml config files or command line argument) into a Python object. """ # Configs parsed from raw yaml will contain dictionary keys that need to be # converted to AttrDict objects if isinstance(v, dict): return AttrDict(v) # All remaining processing is only applied to strings if not isinstance(v, six.string_types): return v # Try to interpret `v` as a: # string, number, tuple, list, dict, boolean, or None try: v = literal_eval(v) # The following two excepts allow v to pass through when it represents a # string. # # Longer explanation: # The type of v is always a string (before calling literal_eval), but # sometimes it *represents* a string and other times a data structure, like # a list. In the case that v represents a string, what we got back from the # yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is # ok with '"foo"', but will raise a ValueError if given 'foo'. In other # cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval # will raise a SyntaxError. except ValueError: pass except SyntaxError: pass return v def _check_and_coerce_cfg_value_type(value_a, value_b, key, full_key): """Checks that `value_a`, which is intended to replace `value_b` is of the right type. The type is correct if it matches exactly or is one of a few cases in which the type can be easily coerced. """ # The types must match (with some exceptions) type_b = type(value_b) type_a = type(value_a) if type_a is type_b: return value_a # Exceptions: numpy arrays, strings, tuple<->list if isinstance(value_b, np.ndarray): value_a = np.array(value_a, dtype=value_b.dtype) elif isinstance(value_b, six.string_types): value_a = str(value_a) elif isinstance(value_a, tuple) and isinstance(value_b, list): value_a = list(value_a) elif isinstance(value_a, list) and isinstance(value_b, tuple): value_a = tuple(value_a) else: raise ValueError( 'Type mismatch ({} vs. {}) with values ({} vs. 
{}) for config ' 'key: {}'.format(type_b, type_a, value_b, value_a, full_key) ) return value_a ================================================ FILE: datasets/__init__.py ================================================ """ Copyright (c) 2021 TU Darmstadt Author: Nikita Araslanov License: Apache License 2.0 """ import torch from torch.utils import data from .dataloader_seg import DataSeg from .dataloader_video import DataVideo, DataVideoKinetics def get_sets(task): # fetch the names of data lists depending on the task sets = {} if task == "OxUvA": sets["train_video"] = "train_oxuva" sets["val_video"] = "val_davis2017_480p" sets["val_video_seg"] = "val2_davis2017_480p" elif task == "YTVOS": sets["train_video"] = "train_ytvos" sets["val_video"] = "val_davis2017_480p" sets["val_video_seg"] = "val2_davis2017_480p" elif task == "TrackingNet": sets["train_video"] = "train_tracking" sets["val_video"] = "val_davis2017_480p" sets["val_video_seg"] = "val2_davis2017_480p" elif task == "kinetics400": sets["train_video"] = "train_kinetics400" sets["val_video"] = "val_davis2017_480p" sets["val_video_seg"] = "val2_davis2017_480p" else: raise NotImplementedError("Dataset '{}' not recognised.".format(task)) return sets def get_dataloader(args, cfg, split): assert split in ("train", "val") task = cfg.TRAIN.TASK data_sets = get_sets(cfg.TRAIN.TASK) # total batch size: # of GPUs * batch size per GPU batch_size = cfg.TRAIN.BATCH_SIZE kwargs = {'pin_memory': True, 'num_workers': args.workers} print("Dataloader: # workers {}".format(args.workers)) def _dataloader(dataset, batch_size, shuffle=True, drop_last=False): return data.DataLoader(dataset, batch_size=batch_size, \ shuffle=shuffle, drop_last=drop_last, **kwargs) print("Split: ", split) if split == "train": VideoLoader = DataVideoKinetics if cfg.DATASET.MP4 else DataVideo dataset_video = VideoLoader(cfg, data_sets["train_video"]) return _dataloader(dataset_video, batch_size, drop_last=True) else: dataset_video = DataVideo(cfg, data_sets["val_video"], val=True) dataset_video_seg = DataSeg(cfg, data_sets["val_video_seg"]) return {"val_video": _dataloader(dataset_video, batch_size, shuffle=False), \ "val_video_seg": _dataloader(dataset_video_seg, 1, shuffle=False)} ================================================ FILE: datasets/dataloader_base.py ================================================ """ Copyright (c) 2021 TU Darmstadt Author: Nikita Araslanov License: Apache License 2.0 """ import torch.utils.data as data from utils.palette import custom_palette class DLBase(data.Dataset): def __init__(self, *args, **kwargs): super(DLBase, self).__init__(*args, **kwargs) # RGB self.MEAN = [0.485, 0.456, 0.406] self.STD = [0.229, 0.224, 0.225] self._init_means() def _init_means(self): self.MEAN255 = [255.*x for x in self.MEAN] self.STD255 = [255.*x for x in self.STD] def _init_palette(self, num_classes): self.palette = custom_palette(num_classes) def get_palette(self): return self.palette def remove_labels(self, mask): # Remove labels not in training for ignore_label in self.ignore_labels: mask[mask == ignore_label] = 255 return mask.long() ================================================ FILE: datasets/dataloader_infer.py ================================================ """ Copyright (c) 2021 TU Darmstadt Author: Nikita Araslanov License: Apache License 2.0 """ import os import torch from PIL import Image import numpy as np import torchvision.transforms as tf from .dataloader_base import DLBase class DataSeg(DLBase): def __init__(self, cfg, split, ignore_labels=[], 
\ root=os.path.expanduser('./data'), renorm=False): super(DataSeg, self).__init__() self.cfg = cfg self.root = root self.split = split self.ignore_labels = ignore_labels self._init_palette(self.cfg.DATASET.NUM_CLASSES) # train/val/test splits are pre-cut split_fn = os.path.join(self.root, self.split + ".txt") assert os.path.isfile(split_fn) self.sequence_ids = [] self.sequence_names = [] def add_sequence(name): vlen = len(self.images) assert vlen >= cfg.DATASET.VIDEO_LEN, \ "Detected video shorter [{}] than training length [{}]".format(vlen, \ cfg.DATASET.VIDEO_LEN) self.sequence_ids.append(vlen) self.sequence_names.append(name) return vlen self.images = [] self.masks = [] self.flags = [] token = None with open(split_fn, "r") as lines: for line in lines: _flag, _image, _mask = line.strip("\n").split(' ') # save every frame #_flag = 1 self.flags.append(int(_flag)) _image = os.path.join(cfg.DATASET.ROOT, _image.lstrip('/')) assert os.path.isfile(_image), '%s not found' % _image # each sequence may have a different length # do some book-keeping e.g. to ensure we have # sequences long enough for subsequent sampling _token = _image.split("/")[-2] # parent directory # sequence ID is in the filename #_token = os.path.basename(_image).split("_")[0] if token != _token: if not token is None: add_sequence(token) token = _token self.images.append(_image) if _mask is None: self.masks.append(None) else: _mask = os.path.join(cfg.DATASET.ROOT, _mask.lstrip('/')) #assert os.path.isfile(_mask), '%s not found' % _mask self.masks.append(_mask) # update the last sequence # returns the total amount of frames add_sequence(token) print("Loaded {} sequences".format(len(self.sequence_ids))) # definint data augmentation: print("Dataloader: {}".format(split), " #", len(self.images)) print("\t {}: no augmentation".format(split)) self.tf = tf.Compose([tf.ToTensor(), tf.Normalize(mean=self.MEAN, std=self.STD)]) self._num_samples = len(self.images) def __len__(self): return len(self.sequence_ids) def _mask2tensor(self, mask, num_classes=6): h,w = mask.shape ones = torch.ones(1,h,w) zeros = torch.zeros(num_classes,h,w) max_idx = mask.max() assert max_idx < num_classes, "{} >= {}".format(max_idx, num_classes) return zeros.scatter(0, mask[None, ...], ones) def denorm(self, image): if image.dim() == 3: assert image.dim() == 3, "Expected image [CxHxW]" assert image.size(0) == 3, "Expected RGB image [3xHxW]" for t, m, s in zip(image, self.MEAN, self.STD): t.mul_(s).add_(m) elif image.dim() == 4: # batch mode assert image.size(1) == 3, "Expected RGB image [3xHxW]" for t, m, s in zip((0,1,2), self.MEAN, self.STD): image[:, t, :, :].mul_(s).add_(m) return image def __getitem__(self, index): seq_to = self.sequence_ids[index] seq_from = 0 if index == 0 else self.sequence_ids[index - 1] image0 = Image.open(self.images[seq_from]) w,h = image0.size images, masks, fns, flags = [], [], [], [] tracks = torch.LongTensor(self.cfg.DATASET.NUM_CLASSES).fill_(-1) masks = torch.LongTensor(self.cfg.DATASET.NUM_CLASSES, h, w).zero_() known_ids = set() for t in range(seq_from, seq_to): t0 = t - seq_from image = Image.open(self.images[t]).convert('RGB') fns.append(os.path.basename(self.images[t].replace(".jpg", ""))) flags.append(self.flags[t]) if os.path.isfile(self.masks[t]): mask = Image.open(self.masks[t]) mask = torch.from_numpy(np.array(mask, np.long, copy=False)) unique_ids = np.unique(mask) for oid in unique_ids: if not oid in known_ids: tracks[oid] = t0 known_ids.add(oid) masks[oid] = (mask == oid).long() else: mask = Image.new('L', 
image.size) image = self.tf(image) images.append(image) images = torch.stack(images, 0) seq_name = self.sequence_names[index] flags = torch.LongTensor(flags) return images, images, masks, tracks, len(known_ids), fns, flags, seq_name ================================================ FILE: datasets/dataloader_seg.py ================================================ """ Copyright (c) 2021 TU Darmstadt Author: Nikita Araslanov License: Apache License 2.0 """ import os import torch from PIL import Image import numpy as np import torch.utils.data as data import torchvision.transforms as tf from .dataloader_base import DLBase class DataSeg(DLBase): def __init__(self, cfg, split, ignore_labels=[], \ root=os.path.expanduser('./data'), renorm=False): super(DataSeg, self).__init__() self.cfg = cfg self.root = root self.split = split self.ignore_labels = ignore_labels self._init_palette(cfg.DATASET.NUM_CLASSES) # train/val/test splits are pre-cut split_fn = os.path.join(self.root, "filelists", self.split + ".txt") assert os.path.isfile(split_fn) self.sequence_ids = [] self.sequence_names = [] def add_sequence(name): vlen = len(self.images) assert vlen >= cfg.DATASET.VIDEO_LEN, \ "Detected video shorter [{}] than training length [{}]".format(vlen, \ cfg.DATASET.VIDEO_LEN) self.sequence_ids.append(vlen) self.sequence_names.append(name) return vlen self.images = [] self.masks = [] token = None with open(split_fn, "r") as lines: for line in lines: _image = line.strip("\n").split(' ') _mask = None if len(_image) == 2: _image, _mask = _image else: assert len(_image) == 1 _image = _image[0] _image = os.path.join(cfg.DATASET.ROOT, _image.lstrip('/')) assert os.path.isfile(_image), '%s not found' % _image self.images.append(_image) # each sequence may have a different length # do some book-keeping e.g. 
to ensure we have # sequences long enough for subsequent sampling _token = _image.split("/")[-2] # parent directory # sequence ID is in the filename #_token = os.path.basename(_image).split("_")[0] if token != _token: if not token is None: add_sequence(token) token = _token if _mask is None: self.masks.append(None) else: _mask = os.path.join(cfg.DATASET.ROOT, _mask.lstrip('/')) assert os.path.isfile(_mask), '%s not found' % _mask self.masks.append(_mask) # update the last sequence # returns the total amount of frames add_sequence(token) print("Loaded {} sequences".format(len(self.sequence_ids))) # definint data augmentation: print("Dataloader: {}".format(split), " #", len(self.images)) print("\t {}: no augmentation".format(split)) self.tf = tf.Compose([tf.ToTensor(), tf.Normalize(mean=self.MEAN, std=self.STD)]) self._num_samples = len(self.images) def __len__(self): return len(self.sequence_ids) def denorm(self, image): if image.dim() == 3: assert image.dim() == 3, "Expected image [CxHxW]" assert image.size(0) == 3, "Expected RGB image [3xHxW]" for t, m, s in zip(image, self.MEAN, self.STD): t.mul_(s).add_(m) elif image.dim() == 4: # batch mode assert image.size(1) == 3, "Expected RGB image [3xHxW]" for t, m, s in zip((0,1,2), self.MEAN, self.STD): image[:, t, :, :].mul_(s).add_(m) return image def _mask2tensor(self, mask, num_classes=6): h,w = mask.shape ones = torch.ones(1,h,w) zeros = torch.zeros(num_classes,h,w) max_idx = mask.max() assert max_idx < num_classes, "{} >= {}".format(max_idx, num_classes) return zeros.scatter(0, mask[None, ...], ones) def __getitem__(self, index): seq_to = self.sequence_ids[index] - 1 seq_from = 0 if index == 0 else self.sequence_ids[index-1] - 1 images, masks = [], [] n_obj = 0 for _id_ in range(seq_from, seq_to): image = Image.open(self.images[_id_]).convert('RGB') if self.masks[_id_] is None: mask = Image.new('L', image.size) else: mask = Image.open(self.masks[_id_]) #.convert('L') image = self.tf(image) images.append(image) mask = torch.from_numpy(np.array(mask, np.long, copy=False)) n_obj = max(n_obj, mask.max().item()) masks.append(self._mask2tensor(mask)) images = torch.stack(images, 0) masks = torch.stack(masks, 0) n_obj = torch.LongTensor([n_obj + 1]) # +1 background seq_name = self.sequence_names[index] return images, masks, n_obj, seq_name ================================================ FILE: datasets/dataloader_video.py ================================================ """ Copyright (c) 2021 TU Darmstadt Author: Nikita Araslanov License: Apache License 2.0 """ import os import random import torch import math import glob from PIL import Image import numpy as np from .dataloader_base import DLBase import datasets.daugm_video as tf class DataVideo(DLBase): def __init__(self, cfg, split, val=False): super(DataVideo, self).__init__() self.cfg = cfg self.split = split self.val = val self.cfg_frame_gap = cfg.DATASET.VAL_FRAME_GAP if val else cfg.DATASET.FRAME_GAP self._init_palette(cfg.TRAIN.BATCH_SIZE * cfg.MODEL.GRID_SIZE**2) # train/val/test splits are pre-cut split_fn = os.path.join(cfg.DATASET.ROOT, "filelists", self.split + ".txt") assert os.path.isfile(split_fn) self.images = [] token = None # sequence token (new video when it changes) subsequence = [] ignored = [0] num_frames = [0] def add_sequence(): if cfg.DATASET.VIDEO_LEN > len(subsequence): # found a very short sequence ignored[0] += 1 else: # adding the subsequence self.images.append(tuple(subsequence)) num_frames[0] += len(subsequence) del subsequence[:] with open(split_fn, "r") as 
lines: for line in lines: _line = line.strip("\n").split(' ') assert len(_line) > 0, "Expected at least one path" _image = _line[0] # each sequence may have a different length # do some book-keeping e.g. to ensure we have # sequences long enough for subsequent sampling _token = _image.split("/")[-2] # parent directory # sequence ID is in the filename #_token = os.path.basename(_image).split("_")[0] if token != _token: if not token is None: add_sequence() token = _token # image 1 _image = os.path.join(cfg.DATASET.ROOT, _image.lstrip('/')) #assert os.path.isfile(_image), '%s not found' % _image subsequence.append(_image) # update the last sequence # returns the total amount of frames add_sequence() print("Dataloader: {}".format(split), " / Frame Gap: ", self.cfg_frame_gap) print("Loaded {} sequences / {} ignored / {} frames".format(len(self.images), \ ignored[0], \ num_frames[0])) self._num_samples = num_frames[0] self._init_augm(cfg) def _init_augm(self, cfg): # general (unguided) affine transformations tfs_pre = [tf.CreateMask()] self.tf_pre = tf.Compose(tfs_pre) # photometric noise tfs_affine = [] # guided affine transformations tfs_augm = [] # 1. # general affine transformations # tfs_pre.append(tf.MaskScaleSmallest(cfg.DATASET.SMALLEST_RANGE)) if cfg.DATASET.RND_CROP: tfs_pre.append(tf.MaskRandCrop(cfg.DATASET.CROP_SIZE, pad_if_needed=True)) else: tfs_pre.append(tf.MaskCenterCrop(cfg.DATASET.CROP_SIZE)) if cfg.DATASET.RND_HFLIP: tfs_pre.append(tf.MaskRandHFlip()) # 2. # Guided affine transformation # if cfg.DATASET.GUIDED_HFLIP: tfs_affine.append(tf.GuidedRandHFlip()) # this will add affine transformation if cfg.DATASET.RND_ZOOM: tfs_affine.append(tf.MaskRandScaleCrop(*cfg.DATASET.RND_ZOOM_RANGE)) self.tf_affine = tf.Compose(tfs_affine) self.tf_affine2 = tf.Compose([tf.AffineIdentity()]) tfs_post = [tf.ToTensorMask(), tf.Normalize(mean=self.MEAN, std=self.STD), tf.ApplyMask(-1)] # image to the teacher will have no noise self.tf_post = tf.Compose(tfs_post) def set_num_samples(self, n): print("Re-setting # of samples: {:d} -> {:d}".format(self._num_samples, n)) self._num_samples = n def __len__(self): return len(self.images) #self._num_samples def denorm(self, image): if image.dim() == 3: assert image.dim() == 3, "Expected image [CxHxW]" assert image.size(0) == 3, "Expected RGB image [3xHxW]" for t, m, s in zip(image, self.MEAN, self.STD): t.mul_(s).add_(m) elif image.dim() == 4: # batch mode assert image.size(1) == 3, "Expected RGB image [3xHxW]" for t, m, s in zip((0,1,2), self.MEAN, self.STD): image[:, t, :, :].mul_(s).add_(m) return image def _get_affine(self, params): N = len(params) # construct affine operator affine = torch.zeros(N, 2, 3) aspect_ratio = float(self.cfg.DATASET.CROP_SIZE[0]) / \ float(self.cfg.DATASET.CROP_SIZE[1]) for i, (dy,dx,alpha,scale,flip) in enumerate(params): # R inverse sin = math.sin(alpha * math.pi / 180.) cos = math.cos(alpha * math.pi / 180.) # inverse, note how flipping is incorporated affine[i,0,0], affine[i,0,1] = flip * cos, sin * aspect_ratio affine[i,1,0], affine[i,1,1] = -sin / aspect_ratio, cos # T inverse Rinv * t == R^T * t affine[i,0,2] = -1. * (cos * dx + sin * dy) affine[i,1,2] = -1. 
* (-sin * dx + cos * dy) # T affine[i,0,2] /= float(self.cfg.DATASET.CROP_SIZE[1] // 2) affine[i,1,2] /= float(self.cfg.DATASET.CROP_SIZE[0] // 2) # scaling affine[i] *= scale return affine def _get_affine_inv(self, affine, params): aspect_ratio = float(self.cfg.DATASET.CROP_SIZE[0]) / \ float(self.cfg.DATASET.CROP_SIZE[1]) affine_inv = affine.clone() affine_inv[:,0,1] = affine[:,1,0] * aspect_ratio**2 affine_inv[:,1,0] = affine[:,0,1] / aspect_ratio**2 affine_inv[:,0,2] = -1 * (affine_inv[:,0,0] * affine[:,0,2] + affine_inv[:,0,1] * affine[:,1,2]) affine_inv[:,1,2] = -1 * (affine_inv[:,1,0] * affine[:,0,2] + affine_inv[:,1,1] * affine[:,1,2]) # scaling affine_inv /= torch.Tensor(params)[:,3].view(-1,1,1)**2 return affine_inv def __getitem__(self, index): # searching for the video clip ID sequence = self.images[index] # % len(self.images)] seqlen = len(sequence) assert self.cfg_frame_gap > 0, "Frame gap should be positive" t_window = self.cfg_frame_gap * self.cfg.DATASET.VIDEO_LEN # reduce sampling gap for short clips t_window = min(seqlen, t_window) frame_gap = t_window // self.cfg.DATASET.VIDEO_LEN # strided slice frame_ids = torch.arange(t_window)[::frame_gap] frame_ids = frame_ids[:self.cfg.DATASET.VIDEO_LEN] assert len(frame_ids) == self.cfg.DATASET.VIDEO_LEN # selecting a random start index_start = random.randint(0, seqlen - frame_ids[-1] - 1) # permuting the frames in the batch random_ids = torch.randperm(self.cfg.DATASET.VIDEO_LEN) # adding the offset frame_ids = frame_ids[random_ids] + index_start # forward sequence images = [] for frame_id in frame_ids: fn = sequence[frame_id] images.append(Image.open(fn).convert('RGB')) # 1. general transforms frames, valid = self.tf_pre(images) # 1.1 creating two sequences in forward/backward order frames1, valid1 = frames[:], valid[:] # second copy frames2 = [f.copy() for f in frames] valid2 = [v.copy() for v in valid] # 2. 
guided affine transforms frames1, valid1, affine_params1 = self.tf_affine(frames1, valid1) frames2, valid2, affine_params2 = self.tf_affine2(frames2, valid2) # convert to tensor, zero out the values frames1 = self.tf_post(frames1, valid1) frames2 = self.tf_post(frames2, valid2) # converting the affine transforms aff_reg = self._get_affine(affine_params1) aff_main = self._get_affine(affine_params2) aff_reg_inv = self._get_affine_inv(aff_reg, affine_params1) aff_reg = aff_main # identity affine2_inv aff_main = aff_reg_inv frames1 = torch.stack(frames1, 0) frames2 = torch.stack(frames2, 0) assert frames1.shape == frames2.shape return frames2, frames1, aff_main, aff_reg class DataVideoKinetics(DataVideo): def __init__(self, cfg, split): super(DataVideo, self).__init__() self.cfg = cfg self.split = split self._init_palette(cfg.TRAIN.BATCH_SIZE * cfg.MODEL.GRID_SIZE**2) # train/val/test splits are pre-cut split_fn = os.path.join(cfg.DATASET.ROOT, "filelists", self.split + ".txt") assert os.path.isfile(split_fn) self.videos = [] with open(split_fn, "r") as lines: for line in lines: _line = line.strip("\n").split(' ') assert len(_line) > 0, "Expected at least one path" _vid = _line[0] # image 1 _vid = os.path.join(cfg.DATASET.ROOT, _vid.lstrip('/')) #assert os.path.isdir(_vid), "{} does not exist".format(_vid) self.videos.append(_vid) # update the last sequence # returns the total amount of frames print("DataloaderKinetics: {}".format(split)) print("Loaded {} sequences".format(len(self.videos))) self._init_augm(cfg) def __len__(self): return len(self.videos) def __getitem__(self, index): C = self.cfg.DATASET path = self.videos[index] # filenames fns = sorted(glob.glob(path + "/*.jpg")) total_len = len(fns) # temporal window to consider temp_window = min(total_len, C.VIDEO_LEN * C.FRAME_GAP) gap = temp_window / C.VIDEO_LEN start_frame = random.randint(0, total_len - temp_window) images = [] for idx in range(C.VIDEO_LEN): frame_id = start_frame + int(idx * gap) fn = fns[frame_id] images.append(Image.open(fn).convert('RGB')) # 1. general transforms frames, valid = self.tf_pre(images) # 1.1 creating two sequences in forward/backward order frames1, valid1 = frames[:], valid[:] # second copy frames2 = [f.copy() for f in frames] valid2 = [v.copy() for v in valid] # 2. 
guided affine transforms frames1, valid1, affine_params1 = self.tf_affine(frames1, valid1) frames2, valid2, affine_params2 = self.tf_affine2(frames2, valid2) # convert to tensor, zero out the values frames1 = self.tf_post(frames1, valid1) frames2 = self.tf_post(frames2, valid2) # converting the affine transforms affine1 = self._get_affine(affine_params1) affine2 = self._get_affine(affine_params2) affine1_inv = self._get_affine_inv(affine1, affine_params1) affine1 = affine2 # identity affine2_inv affine2 = affine1_inv frames1 = torch.stack(frames1, 0) frames2 = torch.stack(frames2, 0) assert frames1.shape == frames2.shape return frames2, frames1, affine2, affine1 ================================================ FILE: datasets/daugm_video.py ================================================ """ Copyright (c) 2021 TU Darmstadt Author: Nikita Araslanov License: Apache License 2.0 """ import random import numpy as np import torch from PIL import Image import torchvision.transforms as tf import torchvision.transforms.functional as F class Compose: # Composes segtransforms: segtransform.Compose([segtransform.RandScale([0.5, 2.0]), segtransform.ToTensor()]) def __init__(self, segtransform): self.segtransform = segtransform def __call__(self, args, *more_args): # allow for intermediate representations for t in self.segtransform: result = t(args, *more_args) args = result[0] more_args = result[1:] return result class ToTensorMask: def __toByteTensor(self, pic): return torch.from_numpy(np.array(pic, np.int32, copy=False)) def __call__(self, images, masks): new_masks = [] for i, (image, mask) in enumerate(zip(images, masks)): images[i] = F.to_tensor(image) new_masks.append(self.__toByteTensor(mask)) return images, new_masks class CreateMask: """Create mask to hold invalid pixels (e.g. from rotations or downscaling) """ def __call__(self, images): masks = [] for i, image in enumerate(images): masks.append(Image.new("L", image.size)) return images, masks class Normalize: # Normalize tensor with mean and standard deviation along channel: channel = (channel - mean) / std def __init__(self, mean, std=None): if std is None: assert len(mean) > 0 else: assert len(mean) == len(std) self.mean = mean self.std = std def __call__(self, images, masks): for i, (image, mask) in enumerate(zip(images, masks)): if self.std is None: for t, m in zip(image, self.mean): t.sub_(m) else: for t, m, s in zip(image, self.mean, self.std): t.sub_(m).div_(s) return images, masks class ApplyMask: def __init__(self, ignore_label): self.ignore_label = ignore_label def __call__(self, images, masks): for i, (image, mask) in enumerate(zip(images, masks)): mask = mask > 0. images[i] *= (1. - mask.type_as(image)) return images class GuidedRandHFlip: def __call__(self, images, mask, affine=None): if affine is None: affine = [[0.,0.,0.,1.,1.] for _ in images] if random.random() > 0.5: for i, image in enumerate(images): affine[i][4] *= -1 images[i] = F.hflip(image) return images, mask, affine class AffineIdentity(object): def __call__(self, images, masks, affine=None): if affine is None: affine = [[0.,0.,0.,1.,1.] 
for _ in images] return images, masks, affine class MaskRandScaleCrop(object): def __init__(self, scale_from, scale_to): self.scale_from = scale_from self.scale_to = scale_to #assert scale_from >= 1., "Zooming in is not supported yet" def get_scale(self): return random.uniform(self.scale_from, self.scale_to) def get_params(self, h, w, new_scale): # generating random crop # preserves aspect ratio new_h = int(new_scale * h) new_w = int(new_scale * w) # generating if new_scale <= 1.: assert w >= new_w and h >= new_h, "{} vs. {} | {} / {}".format(w, new_w, h, new_h) i = random.randint(0, h - new_h) j = random.randint(0, w - new_w) else: assert w <= new_w and h <= new_h, "{} vs. {} | {} / {}".format(w, new_w, h, new_h) i = random.randint(h - new_h, 0) j = random.randint(w - new_w, 0) return i, j, new_h, new_w def __call__(self, images, masks, affine=None): if affine is None: affine = [[0.,0.,0.,1.,1.] for _ in images] W, H = images[0].size i2 = H / 2 j2 = W / 2 masks_new = [] # one crop for all s = self.get_scale() ii, jj, h, w = self.get_params(H, W, s) # displacement of the centre dy = ii + h / 2 - i2 dx = jj + w / 2 - j2 for k, image in enumerate(images): affine[k][0] = dy affine[k][1] = dx affine[k][3] = 1 / s # scale if s <= 1.: assert ii >= 0 and jj >= 0 # zooming in image_crop = F.crop(image, ii, jj, h, w) images[k] = image_crop.resize((W, H), Image.BILINEAR) mask_crop = F.crop(masks[k], ii, jj, h, w) masks_new.append(mask_crop.resize((W, H), Image.NEAREST)) else: assert ii <= 0 and jj <= 0 # zooming out pad_l = abs(jj) pad_r = w - W - pad_l pad_t = abs(ii) pad_b = h - H - pad_t image_pad = F.pad(image, (pad_l, pad_t, pad_r, pad_b)) images[k] = image_pad.resize((W, H), Image.BILINEAR) mask_pad = F.pad(masks[k], (pad_l, pad_t, pad_r, pad_b), 1) masks_new.append(mask_pad.resize((W, H), Image.NEAREST)) return images, masks, affine class MaskScaleSmallest(object): def __init__(self, smallest_range): self.size = smallest_range def __call__(self, images, masks): assert len(images) > 0, "Non-empty array expected" new_size = self.size[0] + int((self.size[1] - self.size[0]) * random.random()) w, h = images[0].size aspect = w / h if aspect > 1: new_h = new_size new_w = int(new_size * aspect) else: new_w = new_size new_h = int(new_size / aspect) new_size = (new_w, new_h) for i, (image, mask) in enumerate(zip(images, masks)): assert image.size == mask.size assert image.size == (w, h) images[i] = image.resize(new_size, Image.BILINEAR) masks[i] = mask.resize(new_size, Image.NEAREST) return images, masks class MaskRandCrop: def __init__(self, size, pad_if_needed=False): self.size = size # (h, w) self.pad_if_needed = pad_if_needed def __pad(self, img, padding_mode='constant', fill=0): # pad the width if needed pad_width = self.size[1] - img.size[0] pad_height = self.size[0] - img.size[1] if self.pad_if_needed and (pad_width > 0 or pad_height > 0): pad_l = max(0, pad_width // 2) pad_r = max(0, pad_width - pad_l) pad_t = max(0, pad_height // 2) pad_b = max(0, pad_height - pad_t) img = F.pad(img, (pad_l, pad_t, pad_r, pad_b), fill, padding_mode) return img def __call__(self, images, masks): for i, (image, mask) in enumerate(zip(images, masks)): images[i] = self.__pad(image) masks[i] = self.__pad(mask, fill=1) i, j, h, w = tf.RandomCrop.get_params(images[0], self.size) for k, (image, mask) in enumerate(zip(images, masks)): images[k] = F.crop(image, i, j, h, w) masks[k] = F.crop(mask, i, j, h, w) return images, masks class MaskCenterCrop: def __init__(self, size): self.size = size # (h, w) def __call__(self, 
images, masks): for i, (image, mask) in enumerate(zip(images, masks)): images[i] = F.center_crop(image, self.size) masks[i] = F.center_crop(mask, self.size) return images, masks class MaskRandHFlip: def __call__(self, images, masks): if random.random() > 0.5: for i, (image, mask) in enumerate(zip(images, masks)): images[i] = F.hflip(image) masks[i] = F.hflip(mask) return images, masks ================================================ FILE: infer_vos.py ================================================ """ Single-scale inference Copyright (c) 2021 TU Darmstadt Author: Nikita Araslanov License: Apache License 2.0 """ import os import sys import numpy as np import imageio import time import torch.multiprocessing as mp from tqdm import tqdm import torch import torch.nn as nn import torch.nn.functional as F from opts import get_arguments from core.config import cfg, cfg_from_file, cfg_from_list from models import get_model from utils.timer import Timer from utils.sys_tools import check_dir from utils.palette_davis import palette as davis_palette from torch.utils.data import DataLoader from datasets.dataloader_infer import DataSeg from labelprop.common import LabelPropVOS_CRW # deterministic inference from torch.backends import cudnn cudnn.enabled = True cudnn.benchmark = False cudnn.deterministic = True VERBOSE = True def mask2rgb(mask, palette): mask_rgb = palette(mask) mask_rgb = mask_rgb[:,:,:3] return mask_rgb def mask_overlay(mask, image, palette): """Creates an overlayed mask visualisation""" mask_rgb = mask2rgb(mask, palette) return 0.3 * image + 0.7 * mask_rgb class ResultWriter: def __init__(self, key, palette, out_path): self.key = key self.palette = palette self.out_path = out_path self.verbose = VERBOSE def save(self, frames, masks_pred, masks_conf, masks_gt, flags, fn, seq_name): subdir_vos = os.path.join(self.out_path, "{}_vos".format(self.key)) check_dir(subdir_vos, seq_name) subdir_vis = os.path.join(self.out_path, "{}_vis".format(self.key)) check_dir(subdir_vis, seq_name) for frame_id, mask in enumerate(masks_pred.split(1, 0)): mask = mask[0].numpy().astype(np.uint8) filepath = os.path.join(subdir_vos, seq_name, "{}.png".format(fn[frame_id][0])) # saving only every 5th frame if flags[frame_id] != 0: imageio.imwrite(filepath, mask) if self.verbose: frame = frames[frame_id].numpy() #mask_gt = masks_gt[frame_id].numpy().astype(np.uint8) #masks = np.concatenate([mask, mask_gt], 1) #frame = np.concatenate([frame, frame], 2) frame = np.transpose(frame, [1,2,0]) overlay = mask_overlay(mask, frame, self.palette) filepath = os.path.join(subdir_vis, seq_name, "{}.png".format(fn[frame_id][0])) imageio.imwrite(filepath, (overlay * 255.).astype(np.uint8)) def convert_dict(state_dict): new_dict = {} for k,v in state_dict.items(): new_key = k.replace("module.", "") new_dict[new_key] = v return new_dict def mask2tensor(mask, idx, num_classes=cfg.DATASET.NUM_CLASSES): h,w = mask.shape mask_t = torch.zeros(1,num_classes,h,w) mask_t[0, idx] = mask return mask_t def configure_tracks(masks_gt, tracks, num_objects): """Selecting masks for initialisation Args: masks_gt: [T,H,W] tracks: [T,2] """ init_masks = {} # we always have first mask # if there are no instances, it will be simply zero H,W = masks_gt[0].shape[-2:] init_masks[0] = torch.zeros(1, cfg.DATASET.NUM_CLASSES, H, W) for oid in range(cfg.DATASET.NUM_CLASSES): t = tracks[oid].item() if not t in init_masks: init_masks[t] = mask2tensor(masks_gt[oid], oid) else: init_masks[t] += mask2tensor(masks_gt[oid], oid) return init_masks def 
make_onehot(mask, HW): # convert mask tensor with probabilities to a one-hot tensor b,c,h,w = mask.shape mask_up = F.interpolate(mask, HW, mode="bilinear", align_corners=True) one_hot = torch.zeros_like(mask_up) one_hot.scatter_(1, mask_up.argmax(1, keepdim=True), 1) one_hot = F.interpolate(one_hot, (h,w), mode="bilinear", align_corners=True) return one_hot def scale_smallest(frame, a): H,W = frame.shape[-2:] s = a / min(H, W) h, w = int(s * H), int(s * W) return F.interpolate(frame, (h, w), mode="bilinear", align_corners=True) def valid_mask(mask): """From a tensor [1,C,h,w] create [1,C,1,1] 0/1 mask saying which IDs are present""" B,C,h,w = mask.shape valid = mask.flatten(2,3).sum(-1) > 0 valid = valid.type_as(mask).view(B,C,1,1) return valid def merge_mask_ids(masks, key0): merged_mask = torch.zeros_like(masks[key0]) for tt, mask in masks.items(): merged_mask[:,1:] += mask[:,1:] probs, ids = merged_mask.max(1, keepdim=True) merged_mask.zero_() merged_mask.scatter(1, ids, probs) merged_mask[:, :1] = 1 - probs return merged_mask def step_seg(cfg, net, labelprop, frames, mask_init): # dense tracking: start from the 1st frame # keep track of new objects T = frames.shape[0] frames = frames.cuda() # scale smallest if cfg.TEST.INPUT_SIZE > 0: frames = scale_smallest(frames, cfg.TEST.INPUT_SIZE) fetch = {"res3": lambda x: x[0], \ "res4": lambda x: x[1], \ "key": lambda x: x[2]} for t in mask_init.keys(): mask_init[t] = mask_init[t].cuda() H,W = mask_init[0].shape[-2:] scale_as = lambda x, y: F.interpolate(x, y.shape[-2:], mode="bilinear", align_corners=True) scale = lambda x, hw: F.interpolate(x, hw, mode="bilinear", align_corners=True) # context (cxt) will maintain # the reference frame ref_embd = {} # context embeddings ref_masks = {} ref_valid = {} # we will also keep a bunch # of previous frames prev_embd = None prev_masks = None all_masks = [] all_masks_conf = [] def add_result(mask): mask_up = scale(mask, (H, W)) nxt_masks_conf, nxt_masks_id = mask_up.max(1) all_masks.append(nxt_masks_id.cpu()) all_masks_conf.append(nxt_masks_conf.cpu()) # initialising embd0 = net(frames[:1], embd_only=True, norm=True) embd0 = fetch[cfg.TEST.KEY](embd0) mask0 = scale_as(mask_init[0], embd0) add_result(mask0) ref_embd[0] = {0: embd0} ref_masks[0] = {0: mask0} ref_valid[0] = valid_mask(mask0) # [x,c,1,1] # add this to the reference context # if there are objects ref_index = [] if mask_init[0].sum() > 0: ref_index = [0] print(">", end='') for t in range(1, T): print(".", end='') sys.stdout.flush() # next frame frames_batch = frames[t:t+1] # source forward pass nxt_embds = net(frames_batch, embd_only=True, norm=True) # fetching the feature nxt_embd = fetch[cfg.TEST.KEY](nxt_embds) ref_t = [0] if len(ref_index) == 0 else ref_index # for each reference mask # we will create own context, then # propagate the labels and merge the result nxt_masks = {} for t0 in ref_t: cxt_index = labelprop.context_index(t0, t) cxt_embd = [ref_embd[t0][j] for j in cxt_index] cxt_masks = [ref_masks[t0][j] for j in cxt_index] nxt_masks[t0] = labelprop.predict(cxt_embd, cxt_masks, nxt_embd, cxt_index, t) # merging all the masks nxt_mask = sum([ref_valid[tt] * nxt_masks[tt] for tt in nxt_masks.keys()]) #nxt_mask = merge_mask_ids(nxt_masks, ref_t[0]) if t in mask_init: # not t >= 0 print("Adding GT mask t = ", t) # adding the initial mask if just appeared mask_init_dn = scale_as(mask_init[t], nxt_embd) mask_init_dn_s = mask_init_dn.sum(1, keepdim=True) nxt_mask = (1 - mask_init_dn_s) * nxt_mask + mask_init_dn_s * mask_init_dn # adding 
to context ref_embd[t] = {} ref_masks[t] = {} ref_valid[t] = valid_mask(mask_init[t]) ref_index.append(t) add_result(nxt_mask) ref_t = [0] if len(ref_index) == 0 else ref_index # # updating the context # two parts: for every initial index, keep first N and last M frames for t0 in ref_t: ref_embd[t0][t] = nxt_embd.clone() ref_masks[t0][t] = nxt_mask.clone() index_short = labelprop.context_long(t0, t) tsteps = list(ref_embd[t0].keys()) for tt in tsteps: if t - tt > cfg.TEST.CXT_SIZE and not tt in index_short: del ref_embd[t0][tt] del ref_masks[t0][tt] print('<') masks_pred = torch.cat(all_masks, 0) masks_pred_conf = torch.cat(all_masks_conf, 0) return masks_pred, masks_pred_conf if __name__ == '__main__': # loading the model args = get_arguments(sys.argv[1:]) # reading the config cfg_from_file(args.cfg_file) if args.set_cfgs is not None: cfg_from_list(args.set_cfgs) # initialising the dirs check_dir(args.mask_output_dir, "{}_vis".format(cfg.TEST.KEY)) check_dir(args.mask_output_dir, "{}_vos".format(cfg.TEST.KEY)) # Loading the model model = get_model(cfg, remove_layers=cfg.MODEL.REMOVE_LAYERS) labelprop = LabelPropVOS_CRW(cfg) if not os.path.isfile(args.resume): print("[W]: ", "Snapshot not found: {}".format(args.resume)) print("[W]: Using a random model") else: state_dict = convert_dict(torch.load(args.resume)["model"]) try: model.load_state_dict(state_dict) except Exception as e: print("Error while loading the snapshot:\n", str(e)) print("Resuming...") model.load_state_dict(state_dict, strict=False) for p in model.parameters(): p.requires_grad = False # setting the evaluation mode model.eval() model = model.cuda() dataset = DataSeg(cfg, args.infer_list) dataloader = DataLoader(dataset, batch_size=1, shuffle=False, \ drop_last=False) #, num_workers=args.workers) palette = dataloader.dataset.get_palette() timer = Timer() N = len(dataloader) pool = mp.Pool(processes=args.workers) writer = ResultWriter(cfg.TEST.KEY, davis_palette, args.mask_output_dir) for iter, batch in enumerate(dataloader): frames_orig, frames, masks_gt, tracks, num_ids, fns, flags, seq_name = batch print("Sequence {:02d} | {}".format(iter, seq_name[0])) masks_gt = masks_gt.flatten(0,1) frames_orig = frames_orig.flatten(0,1) frames = frames.flatten(0,1) tracks = tracks.flatten(0,1) flags = flags.flatten(0,1) init_masks = configure_tracks(masks_gt, tracks, num_ids[0]) assert 0 in init_masks, "initial frame has no instances" with torch.no_grad(): masks_pred, masks_conf = step_seg(cfg, model, labelprop, frames, init_masks) frames_orig = dataset.denorm(frames_orig) pool.apply_async(writer.save, args=(frames_orig, masks_pred.cpu(), masks_conf.cpu(), masks_gt.cpu(), flags, fns, seq_name[0])) timer.stage("Inference completed") pool.close() pool.join() ================================================ FILE: labelprop/common.py ================================================ """ Based on the inference routines from Jabri et al., (2020) Credit: https://github.com/ajabri/videowalk.git License: MIT """ import sys import torch from labelprop.crw import MaskedAttention from labelprop.crw import mem_efficient_batched_affinity as batched_affinity class LabelPropVOS(object): def context_long(self): """Returns indices of the timesteps for long-term memory """ raise NotImplementedError() def context_short(self, t): """ Args: t: current timestep Returns: list: indices of timesteps for the context """ raise NotImplementedError() def predict(self, feats, masks, curr_feat): """ Args: feats [C,K,h,w]: context features masks [C,M,h,w]: context 
masks curr_feat [1,K,h,w]: current frame features Returns: mask [1,M,h,w]: current frame mask """ raise NotImplementedError() class LabelPropVOS_CRW(LabelPropVOS): def __init__(self, cfg): self.cxt_size = cfg.TEST.CXT_SIZE self.radius = cfg.TEST.RADIUS self.temperature = cfg.TEST.TEMP self.topk = cfg.TEST.KNN self.mask = None self.mask_hw = None def context_long(self, t0, t): return [t0] def context_short(self, t0, t): to_t = t from_t = to_t - self.cxt_size timesteps = [max(tt, t0) for tt in range(from_t, to_t)] return timesteps def context_index(self, t0, t): index_short = self.context_short(t0, t) index_long = self.context_long(t0, t) cxt_index = index_long + index_short return cxt_index def predict(self, feats, masks, curr_feat, ref_index=None, t=None): """ Args: feats: list of C [1,K,h,w] context features masks: list of C [1,M,h,w] context masks curr_feat: [1,K,h,w] current frame features ref_index: C indices of context frames t: current frame time step Returns: mask [1,M,h,w]: current frame mask """ dev = curr_feat.device h, w = curr_feat.shape[-2:] # [BC+N,M,h,w] -> [BC+N,h,w,M] ctx_lbls = torch.cat(masks, 0).permute([0,2,3,1]) # keys [1,K,C,h,w]: context features keys = torch.stack(feats, 2)[:, :, None] # query: [1,K,1,h,w]: reference feature query = curr_feat[:, :, None] if self.mask is None or self.mask_hw != (h, w): # Make spatial radius mask TODO use torch.sparse restrict = MaskedAttention(self.radius, flat=False) D = restrict.mask(h, w)[None] D = D.flatten(-4, -3).flatten(-2) D[D==0] = -1e10; D[D==1] = 0 self.mask = D.to(dev) self.mask_hw = (h, w) # Flatten source frame features to make context feature set keys, query = keys.flatten(-2), query.flatten(-2) long_mem = [0] Ws, Is = batched_affinity(query, keys, self.mask, \ self.temperature, self.topk, long_mem) # Soft labels of source nodes ctx_lbls = ctx_lbls.flatten(0, 2).transpose(0, 1) # Weighted sum of top-k neighbours (Is is index, Ws is weight) pred = (ctx_lbls[:, Is[0].to(dev)] * Ws[0][None].to(dev)).sum(1) pred = pred.view(-1, h, w) pred = pred.permute(1,2,0) # Adding Predictions pred = pred.permute([2,0,1])[None, ...] 
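The context assembled for each prediction mixes one long-term index (the reference frame t0) with a short window of the most recent frames, clamped so it never reaches before t0. A standalone sketch of that indexing, with an assumed context size of 5 (the actual value comes from cfg.TEST.CXT_SIZE):

def context_short(t0, t, cxt_size):
    # last cxt_size timesteps before t, clamped so we never go before the reference t0
    return [max(tt, t0) for tt in range(t - cxt_size, t)]

def context_index(t0, t, cxt_size):
    # long-term memory (the reference frame) + short-term memory (recent frames)
    return [t0] + context_short(t0, t, cxt_size)

# e.g. reference frame 0, current frame 12, context of 5 frames:
print(context_index(0, 12, 5))   # [0, 7, 8, 9, 10, 11]
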
return pred ================================================ FILE: labelprop/crw.py ================================================ """ Inference routines from Jabri et al., (2020) Credit: https://github.com/ajabri/videowalk.git License: MIT """ import torch import torch.nn as nn import torch.nn.functional as F import time class CRW(object): """Propagation algorithm""" def __init__(self, cfg): self.n_context = cfg.CXT_SIZE self.radius = cfg.RADIUS self.temperature = cfg.TEMP self.topk = cfg.KNN print("Inference Opts:") print("Context size: {}".format(self.n_context)) print(" Radius: {}".format(self.radius)) print(" Temp: {}".format(self.temperature)) print(" TopK: {}".format(self.topk)) # always keeping the first frame # TODO: move to cfg self.long_mem = [0] # for bwd-compatibility self.norm_mask = False self.mask = None self.mask_hw = None def _prep_context(self, feats, lbls, hw): """Adjust for context Args: lbls: [N,M,H,W] """ lbls = F.interpolate(lbls, hw, mode="bilinear", align_corners=True) ref = lbls[:1].expand(self.n_context,-1,-1,-1) lbls = torch.cat([ref, lbls], 0) fref = feats[:1].expand(self.n_context,-1,-1,-1) fref = torch.cat([fref, feats], 0) return fref, lbls def forward(self, feats, lbls): """Propagate features Args: feats: [N,K,h,w] lbls: [N,M,H,W] """ N,K,h,w = feats.shape M,H,W = lbls.shape[-3:] feats, lbls = self._prep_context(feats, lbls, (h,w)) # [N,K,h,w] -> [1,K,N,h,w] feats = feats.permute([1,0,2,3]) # singleton for compatibility feats = feats[None,...] # [BC+N,M,h,w] -> [BC+N,h,w,M] lbls = lbls.permute([0,2,3,1]) n_context = self.n_context torch.cuda.empty_cache() # Prepare source (keys) and target (query) frame features key_indices = context_index_bank(n_context, self.long_mem, N) key_indices = torch.cat(key_indices, dim=-1) keys = feats[:, :, key_indices] query = feats[:, :, n_context:] # Make spatial radius mask TODO use torch.sparse if self.mask is None or self.mask_hw != (h, w): restrict = MaskedAttention(self.radius, flat=False) D = restrict.mask(h, w)[None] D = D.flatten(-4, -3).flatten(-2) D[D==0] = -1e10; D[D==1] = 0 self.mask = D.cuda() self.mask_hw = (h, w) # Flatten source frame features to make context feature set keys, query = keys.flatten(-2), query.flatten(-2) Ws, Is = mem_efficient_batched_affinity(query, keys, self.mask, \ self.temperature, self.topk, self.long_mem) ################################################################## # Propagate Labels and Save Predictions ################################################################### masks_idx = torch.LongTensor(N,H,W) masks_prob = torch.FloatTensor(N,H,W) for t in range(key_indices.shape[0]): # Soft labels of source nodes ctx_lbls = lbls[key_indices[t]].cuda() ctx_lbls = ctx_lbls.flatten(0, 2).transpose(0, 1) # Weighted sum of top-k neighbours (Is is index, Ws is weight) pred = (ctx_lbls[:, Is[t]] * Ws[t][None].cuda()).sum(1) pred = pred.view(-1, h, w) pred = pred.permute(1,2,0) if t > 0: lbls[t + n_context] = pred else: pred = lbls[0] lbls[t + n_context] = pred if self.norm_mask: pred[:, :, :] -= pred.min(-1)[0][:, :, None] pred[:, :, :] /= pred.max(-1)[0][:, :, None] # Adding Predictions pred_ = pred.permute([2,0,1])[None, ...] 
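The propagation step itself is a temperature-scaled top-k attention from current-frame pixels onto context pixels, followed by a weighted average of the context labels. A minimal, unbatched sketch of this idea; the spatial radius mask and the mini-batching of mem_efficient_batched_affinity are omitted, and the topk and temperature values are placeholders rather than the config defaults:

import torch
import torch.nn.functional as F

def propagate_labels(ctx_feats, ctx_lbls, query_feats, topk=10, temperature=0.05):
    # ctx_feats:   [P, K] L2-normalised context features (P = number of context pixels)
    # ctx_lbls:    [P, M] soft labels of the context pixels
    # query_feats: [Q, K] L2-normalised features of the current frame
    # returns:     [Q, M] propagated soft labels
    A = query_feats @ ctx_feats.t() / temperature      # [Q, P] affinities
    w, idx = torch.topk(A, topk, dim=-1)               # keep the k nearest context pixels
    w = F.softmax(w, dim=-1)                           # normalise their weights
    return (ctx_lbls[idx] * w.unsqueeze(-1)).sum(1)    # weighted label average

# toy check: 64 context pixels with 3 classes, 16 query pixels
ctx_f = F.normalize(torch.randn(64, 128), dim=1)
qry_f = F.normalize(torch.randn(16, 128), dim=1)
ctx_l = F.softmax(torch.randn(64, 3), dim=1)
print(propagate_labels(ctx_f, ctx_l, qry_f).shape)     # torch.Size([16, 3])
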
pred_up = F.interpolate(pred_, (H,W), mode="bilinear", align_corners=True) pred_up = pred_up[0].cpu() masks_idx[t] = pred_up.argmax(0) masks_prob[t] = pred_up[1] out = {} out["masks_pred_idx"] = masks_idx out["masks_pred_conf"] = masks_prob return out def context_index_bank(n_context, long_mem, N): ''' Construct bank of source frames indices, for each target frame ''' ll = [] # "long term" context (i.e. first frame) for t in long_mem: assert 0 <= t < N, 'context frame out of bounds' idx = torch.zeros(N, 1).long() if t > 0: idx += t + (n_context+1) idx[:n_context+t+1] = 0 ll.append(idx) # "short" context ss = [(torch.arange(n_context)[None].repeat(N, 1) + torch.arange(N)[:, None])[:, :]] return ll + ss def batched_affinity(query, keys, mask, temperature, topk, long_mem, device): ''' Mini-batched computation of affinity, for memory efficiency (less aggressively mini-batched) ''' A = torch.einsum('ijklm,ijkn->iklmn', keys, query) # Mask A[0, :, len(long_mem):] += mask.to(device) _, N, T, h1w1, hw = A.shape A = A.view(N, T*h1w1, hw) A /= temperature weights, ids = torch.topk(A, topk, dim=-2) weights = F.softmax(weights, dim=-2) Ws = [w for w in weights] Is = [ii for ii in ids] return Ws, Is def mem_efficient_batched_affinity(query, keys, mask, temperature, topk, long_mem): ''' Mini-batched computation of affinity, for memory efficiency ''' bsize, pbsize = 2, 100 Ws, Is = [], [] for b in range(0, keys.shape[2], bsize): _k, _q = keys[:, :, b:b+bsize].cuda(), query[:, :, b:b+bsize].cuda() w_s, i_s = [], [] for pb in range(0, _k.shape[-1], pbsize): A = torch.einsum('ijklm,ijkn->iklmn', _k, _q[..., pb:pb+pbsize]) A[0, :, len(long_mem):] += mask[..., pb:pb+pbsize] _, N, T, h1w1, hw = A.shape A = A.view(N, T*h1w1, hw) A /= temperature weights, ids = torch.topk(A, topk, dim=-2) weights = F.softmax(weights, dim=-2) w_s.append(weights) i_s.append(ids) weights = torch.cat(w_s, dim=-1) ids = torch.cat(i_s, dim=-1) Ws += [w for w in weights] Is += [ii for ii in ids] return Ws, Is class MaskedAttention(nn.Module): ''' A module that implements masked attention based on spatial locality TODO implement in a more efficient way (torch sparse or correlation filter) ''' def __init__(self, radius, flat=True): super(MaskedAttention, self).__init__() self.radius = radius self.flat = flat self.masks = {} self.index = {} def mask(self, H, W): if not ('%s-%s' %(H,W) in self.masks): self.make(H, W) return self.masks['%s-%s' %(H,W)] def index(self, H, W): if not ('%s-%s' %(H,W) in self.index): self.make_index(H, W) return self.index['%s-%s' %(H,W)] def make(self, H, W): if self.flat: H = int(H**0.5) W = int(W**0.5) gx, gy = torch.meshgrid(torch.arange(0, H), torch.arange(0, W)) D = ( (gx[None, None, :, :] - gx[:, :, None, None])**2 + (gy[None, None, :, :] - gy[:, :, None, None])**2 ).float() ** 0.5 D = (D < self.radius)[None].float() if self.flat: D = self.flatten(D) self.masks['%s-%s' %(H,W)] = D return D def flatten(self, D): return torch.flatten(torch.flatten(D, 1, 2), -2, -1) def make_index(self, H, W, pad=False): mask = self.mask(H, W).view(1, -1).byte() idx = torch.arange(0, mask.numel())[mask[0]][None] self.index['%s-%s' %(H,W)] = idx return idx def forward(self, x): H, W = x.shape[-2:] sid = '%s-%s' % (H,W) if sid not in self.masks: self.masks[sid] = self.make(H, W).to(x.device) mask = self.masks[sid] return x * mask[0] ================================================ FILE: launch/infer_vos.sh ================================================ #!/bin/bash # # Arguments # # suffix for the output directory (see below) 
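The MaskedAttention module in labelprop/crw.py restricts this attention to a spatial neighbourhood: positions farther than `radius` pixels apart receive a large negative additive bias before the softmax. A compact sketch of that mask construction; shapes follow the repo code, and indexing="ij" for torch.meshgrid is an assumption for newer PyTorch versions:

import torch

def radius_mask(H, W, radius):
    # additive attention mask: 0 within `radius` pixels, -1e10 outside (cf. MaskedAttention.make)
    gy, gx = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
    D = ((gx[None, None] - gx[:, :, None, None]) ** 2 +
         (gy[None, None] - gy[:, :, None, None]) ** 2).float().sqrt()
    D = (D < radius).float().flatten(0, 1).flatten(-2)          # [H*W, H*W]
    return torch.where(D > 0, torch.zeros_like(D), torch.full_like(D, -1e10))

print(radius_mask(8, 8, 3).shape)   # torch.Size([64, 64])
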
VER=v01 # defines the path to the snapshot (see below) EXP=ours RUN_ID=final # SNAPSHOT name. Note .pth will be attached # See README.md to download these snapshots SNAPSHOT=ytvos_e060_res4 # YouTube-VOS #SNAPSHOT=trackingnet_e088_res4 # TrackingNet #SNAPSHOT=oxuva_e430_res4 # OxUvA #SNAPSHOT=kinetics_e026_res4 # Kinetics # codename of the final output layer [res3|res4|key] KEY=res4 # # Changing the following is not necessary # DS=$1 case $DS in ytvos) echo "Test dataset: YouTube-VOS 2018 (val)" FILELIST=filelists/val_ytvos2018_test ;; davis) echo "Test dataset: DAVIS-2017 (val)" FILELIST=filelists/val_davis2017_test ;; *) echo "Dataset '$DS' not recognised. Should be one of [ytvos|davis]." exit 1 ;; esac # The config file is irrelevant # since the inference parameters are # always the same CONFIG=configs/ytvos.yaml OUTPUT_DIR=./results EXTRA="$EXTRA --seed 0 --set TEST.KEY $KEY" SAVE_ID=${RUN_ID}_${SNAPSHOT}_${VER} SNAPSHOT_PATH=snapshots/${EXP}/${RUN_ID}/${SNAPSHOT}.pth if [ ! -f $SNAPSHOT_PATH ]; then echo "Snapshot $SNAPSHOT_PATH NOT found." exit 1; fi # # Code goes here # LISTNAME=`basename $FILELIST .txt` SAVE_DIR=$OUTPUT_DIR/$EXP/$SAVE_ID/$LISTNAME LOG_FILE=$OUTPUT_DIR/$EXP/$SAVE_ID/${LISTNAME}_${KEY}.log NUM_THREADS=12 export OMP_NUM_THREADS=$NUM_THREADS export MKL_NUM_THREADS=$NUM_THREADS CMD="python infer_vos.py --cfg $CONFIG \ --exp $EXP \ --run $RUN_ID \ --resume $SNAPSHOT_PATH \ --infer-list $FILELIST \ --mask-output-dir $SAVE_DIR \ $EXTRA" if [ ! -d $SAVE_DIR ]; then echo "Creating directory: $SAVE_DIR" mkdir -p $SAVE_DIR else echo "Saving to: $SAVE_DIR" fi git rev-parse HEAD > ${SAVE_DIR}.head git diff > ${SAVE_DIR}.diff echo $CMD > ${SAVE_DIR}.cmd echo $CMD nohup $CMD > $LOG_FILE 2>&1 & sleep 1 tail -f $LOG_FILE ================================================ FILE: launch/train.sh ================================================ #!/bin/bash # Set the following variables # The tensorboard logging will be creating in logs// # The snapshots will be saved in snapshots// EXP=v01_vos EXP_ID=v01_00_base # # No change are necessary starting here # DS=$1 SEED=32 case $DS in oxuva) echo "Train dataset: OxUvA" TASK="OxUvA_all" CFG=configs/oxuva.yaml EXP_ID="OX_${EXP_ID}" ;; ytvos) echo "Train dataset: YouTube-VOS" TASK="YTVOS" CFG=configs/ytvos.yaml EXP_ID="YT_${EXP_ID}" ;; track) echo "Train dataset: TrackingNet" TASK="TrackingNet" CFG=configs/tracknet.yaml EXP_ID="TN_${EXP_ID}" ;; kinetics) echo "Train dataset: Kinetics-400" CFG=configs/kinetics.yaml EXP_ID="KT_${EXP_ID}" ;; *) echo "Dataset '$DS' not recognised. Should be one of [oxuva|ytvos|track|kinetics]." exit 1 ;; esac CURR_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" source $CURR_DIR/utils.bash CMD="python train.py --cfg $CFG --exp $EXP --run $EXP_ID --seed $SEED" LOG_DIR=logs/${EXP} LOG_FILE=$LOG_DIR/${EXP_ID}.log echo "LOG: $LOG_FILE" check_rundir $LOG_DIR $EXP_ID NUM_THREADS=12 export OMP_NUM_THREADS=$NUM_THREADS export MKL_NUM_THREADS=$NUM_THREADS echo $CMD CMD_FILE=$LOG_DIR/${EXP_ID}.cmd echo $CMD > $CMD_FILE git rev-parse HEAD > $LOG_DIR/${EXP_ID}.head git diff > $LOG_DIR/${EXP_ID}.diff nohup $CMD > $LOG_FILE 2>&1 & sleep 1 tail -f $LOG_FILE ================================================ FILE: launch/utils.bash ================================================ #!/bin/bash check_rundir() { LOG_DIR="$1" EXP_ID="$2" if [ ! -d "$LOG_DIR" ]; then echo "Creating directory $LOG_DIR" mkdir -p $LOG_DIR else LOGD=$LOG_DIR/$EXP_ID if [ -d "$LOGD" ]; then echo "Directory $LOGD already exists." 
read -p "Do you want to remove the log files?: " -n 1 -r echo if [[ ! $REPLY =~ ^[Yy]$ ]] then exit; else rm -rf $LOGD fi fi fi } ================================================ FILE: models/__init__.py ================================================ """ Copyright (c) 2021 TU Darmstadt Author: Nikita Araslanov License: Apache License 2.0 """ from .resnet18 import resnet18 from .net import Net from .framework import Framework def get_model(cfg, *args, **kwargs): backbones = { 'resnet18': resnet18 } def create_net(): backbone = backbones[cfg.MODEL.ARCH.lower()](*args, **kwargs) return Net(cfg, backbone) net = create_net() return Framework(cfg, net) ================================================ FILE: models/base.py ================================================ """ Base class for network models Copyright (c) 2021 TU Darmstadt Author: Nikita Araslanov License: Apache License 2.0 """ import torch import torch.nn as nn import torch.nn.functional as F class BaseNet(nn.Module): _trainable = (nn.Linear, nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d, nn.GroupNorm, nn.InstanceNorm2d, nn.SyncBatchNorm) _batchnorm = (nn.BatchNorm2d, nn.SyncBatchNorm, nn.GroupNorm) def __init__(self): super().__init__() # we may want a different learning rate # for new layers self.from_scratch_layers = [] # we may want to freeze some layers self.not_training = [] # we may want to freeze BN (means/stds only) self.bn_freeze = [] def lr_mult(self): """Learning rate multiplier for weights. Returns: [old, new]""" return 1., 1. def lr_mult_bias(self): """Learning rate multiplier for bias. Returns: [old, new]""" return 2., 2. def _is_learnable(self, layer): return isinstance(layer, BaseNet._trainable) def _from_scratch(self, net, ignore=[]): for layer in net.modules(): if self._is_learnable(layer): self.from_scratch_layers.append(layer) def _freeze_bn(self, net, ignore=[]): """Add layers to use in .eval() mode only""" for layer in net.modules(): if isinstance(layer, BaseNet._batchnorm) and \ not layer in ignore: assert hasattr(layer, "eval") and callable(layer.eval) print("Freezing ", layer) self.bn_freeze.append(layer) print("Frozen BN: ", len(self.bn_freeze)) def _fix_bn(self, layer): if isinstance(layer, nn.BatchNorm2d): self.not_training.append(layer) elif isinstance(layer, nn.Module): for c in layer.children(): self._fix_bn(c) def __set_grad_mode(self, layer, mode, only_type=None): if hasattr(layer, "weight"): if only_type is None or isinstance(layer, only_type): layer.weight.requires_grad = mode if hasattr(layer, "bias") and not layer.bias is None: if only_type is None or isinstance(layer, only_type): layer.bias.requires_grad = mode if isinstance(layer, nn.Module): for c in layer.children(): self.__set_grad_mode(c, mode) def train(self, mode=True): super().train(mode) # some layers have to be frozen for layer in self.not_training: self.__set_grad_mode(layer, False) for layer in self.bn_freeze: assert hasattr(layer, "eval") and callable(layer.eval) layer.eval() def parameter_groups(self, base_lr, wd): w_old, w_new = self.lr_mult() b_old, b_new = self.lr_mult_bias() groups = ({"params": [], "weight_decay": wd, "lr": w_old*base_lr}, # weight learning {"params": [], "weight_decay": 0.0, "lr": b_old*base_lr}, # bias learning {"params": [], "weight_decay": wd, "lr": w_new*base_lr}, # weight finetuning {"params": [], "weight_decay": 0.0, "lr": b_new*base_lr}) # bias finetuning for m in self.modules(): if not self._is_learnable(m): if hasattr(m, "weight") or hasattr(m, "bias"): print("Skipping layer with parameters: ", m) 
continue if not m.weight is None and m.weight.requires_grad: if m in self.from_scratch_layers: groups[2]["params"].append(m.weight) else: groups[0]["params"].append(m.weight) elif not m.weight is None: print("Skipping W: ", m, m.weight.size()) if m.bias is not None and m.bias.requires_grad: if m in self.from_scratch_layers: groups[3]["params"].append(m.bias) else: groups[1]["params"].append(m.bias) elif m.bias is not None: print("Skipping b: ", m, m.bias.size()) return groups @staticmethod def _resize_as(x, y): return F.interpolate(x, y.size()[-2:], mode="bilinear", align_corners=True) ================================================ FILE: models/framework.py ================================================ """ Copyright (c) 2021 TU Darmstadt Author: Nikita Araslanov License: Apache License 2.0 """ import torch import torch.nn as nn import torch.nn.functional as F from models.base import BaseNet class Framework(BaseNet): def __init__(self, cfg, net): super(Framework, self).__init__() self.cfg = cfg self.fast_net = net self.eye = None def parameter_groups(self, base_lr, wd): return self.fast_net.parameter_groups(base_lr, wd) def _align(self, x, t): tf = F.affine_grid(t, size=x.size(), align_corners=False) return F.grid_sample(x, tf, align_corners=False, mode="nearest") def _key_val(self, ctr, q): """ Args: ctr: [N,K] q: [BHW,K] Returns: val: [BHW,N] """ # [BHW,K] x [N,K].t -> [BHWxN] vals = torch.mm(q, ctr.t()) # [BHW,N] # normalising attention return vals / self.cfg.TEST.TEMP def _sample_index(self, x, T, N): """Sample indices of the anchors Args: x: [BT,K,H,W] Returns: index: [B,N*N,K] """ BT,K,H,W = x.shape B = x.view(-1,T,K,H*W).shape[0] # sample indices from a uniform grid xs, ys = W // N, H // N x_sample = torch.arange(0, W, xs).view(1, 1, N) y_sample = torch.arange(0, H, ys).view(1, N, 1) # Random offsets # [B x 1 x N] x_sample = x_sample + torch.randint(0, xs, (B, 1, 1)) # [B x N x 1] y_sample = y_sample + torch.randint(0, ys, (B, 1, 1)) # batch index # [B x N x N] hw_index = torch.LongTensor(x_sample + y_sample * W) return hw_index def _sample_from(self, x, index, T, N): """Gather the features based on the index Args: x: [BT,K,H,W] index: [B,N,N] defines the indices of NxN grid for a single frame in each of B video clips Returns: anchors: [BNN,K] sampled features given by index from x """ BT,K,H,W = x.shape x = x.view(-1,T,K,H*W) B = x.shape[0] # > [B,T,K,HW] > [B,T,HW,K] > [B,THW,K] x = x.permute([0,1,3,2]).reshape(B,-1,K) # every video clip will have the same samples # on the grid # [B x N x N] -> [B x N*N x 1] -> [B x N*N x K] index = index.view(B,-1,1).expand(-1,-1,K) # selecting from the uniform grid y = x.gather(1, index.to(x.device)) # [BNN,K] return y.flatten(0,1) def _mark_from(self, x, index, T, N, fill_value=0): """This is analogous to _sample_from except that here we simply "mark" the sampled positions in the tensor Used for visualisation only. 
Since it is a binary mask, K == 1 Args: x: [BT,1,H,W] binary mask index: [B,N,N] defines the indices of NxN grid for a single frame in each of B video clips Returns: y: [BT,1,H,W] marked sample positions """ BT,K,H,W = x.shape assert K == 1, "Expected binary mask" x = x.view(-1,T,K,H*W) B = x.shape[0] # > [B,T,K,HW] > [B,T,HW,K] > [B,THW,K] x = x.permute([0,1,3,2]).reshape(B,-1,K) # every video clip will have the same samples # on the grid # [B x N x N] -> [B x N*N x 1] -> [B x N*N x K] index = index.view(B,-1,1).expand(-1,-1,K) # selecting from the uniform grid # [B x T*H*W x K] y = x.scatter(1, index.to(x.device), fill_value) # [B x T*H*W x K] -> [BT x K x H x W] return y.view(-1,H*W,K).permute([0,2,1]).view(-1,K,H,W) def _cluster_grid(self, k1, k2, aff1, aff2, T, index=None): """ Selecting clusters within a sequence Args: k1: [BT,K,H,W] k2: [BT,K,H,W] """ BT,K,H,W = k1.shape assert BT % T == 0, "Batch not divisible by sequence length" B = BT // T # N = [G x G] N = self.cfg.MODEL.GRID_SIZE ** 2 # [BT,K,H,W] -> [BTHW,K] flatten = lambda x: x.flatten(2,3).permute([0,2,1]).flatten(0,1) # [BTHW,BN] -> [BT,BN,H,W] def unflatten(x, aff=None): x = x.view(BT,H*W,-1).permute([0,2,1]).view(BT,-1,H,W) if aff is None: return x return self._align(x, aff) index = self._sample_index(k1, T, N = self.cfg.MODEL.GRID_SIZE) query1 = self._sample_from(k1, index, T, N = self.cfg.MODEL.GRID_SIZE) """Computing the distances and pseudo labels""" # [BTHW,K] k1_ = flatten(k1) k2_ = flatten(k2) # [BTHW,BN] -> [BTHW,BN] -> [BT,BN,H,W] vals_soft = unflatten(self._key_val(query1, k1_), aff1) vals_pseudo = unflatten(self._key_val(query1, k2_), aff2) # [BT,BN,H,W] probs_pseudo = self._pseudo_mask(vals_pseudo, T) probs_pseudo2 = self._pseudo_mask(vals_soft, T) pseudo = probs_pseudo.argmax(1) pseudo2 = probs_pseudo2.argmax(1) # mask def grid_mask(): grid_mask = torch.ones(BT,1,H,W).to(pseudo.device) return self._mark_from(grid_mask, index, T, N = self.cfg.MODEL.GRID_SIZE) return vals_soft, pseudo, index, [vals_pseudo, pseudo2, grid_mask] # sampling affinity def _aff_sample(self, k1, k2, T): BT,K,h,w = k1.size() B = BT // T hw = h*w def gen(query): grid_mask = torch.ones(B,1,hw).to(k1.device) # generating random indices indices = torch.randint(0, hw, (B,1,1)).to(k1.device) grid_mask.scatter_(2, indices, 0) # [B,K,H,W] -> [B,K,1] query_ = query[::T].view(B,K,-1).gather(2, indices.expand(-1,K,-1)) def aff(keys): k = keys.view(B,T,K,-1) # [B,T,K,HW] x [B,1,K,HW] -> [B,T,HW] aff = (k * query_[:,None,:,:]).sum(2) return (aff + 1) / 2 aff1 = aff(k1) aff2 = aff(k2) return grid_mask.view(B,h,w), aff1.view(BT,h,w), aff2.view(BT,h,w) grid_mask1, aff1_1, aff1_2 = gen(k1) grid_mask2, aff2_1, aff2_2 = gen(k2) return grid_mask1, aff1_1, aff1_2, \ grid_mask2, aff2_1, aff2_2 def _pseudo_mask(self, logits, T): BT,K,h,w = logits.shape assert BT % T == 0, "Batch not divisible by sequence length" B = BT // T # N = [G x G] N = self.cfg.MODEL.GRID_SIZE ** 2 # generating a pseudo label # first we need to mask out the affinities across the batch if self.eye is None or self.eye.shape[0] != B*T \ or self.eye.shape[1] != B*N: eye = torch.eye(B)[:,:,None].expand(-1,-1,N).reshape(B,-1) eye = eye.unsqueeze(1).expand(-1,T,-1).reshape(B*T, B*N, 1, 1) self.eye = eye.to(logits.device) probs = F.softmax(logits, 1) return probs * self.eye def _ref_loss(self, x, y, N = 4): B,_,h,w = x.shape index = self._sample_index(x, T=1, N=N) x1 = self._sample_from(x, index, T=1, N=N) y1 = self._sample_from(y, index, T=1, N=N) logits = torch.mm(x1, y1.t()) / 
self.cfg.TEST.TEMP labels = torch.arange(logits.size(1)).to(logits.device) return F.cross_entropy(logits, labels) def _ce_loss(self, x, pseudo_map, T, eps=1e-5): error_map = F.cross_entropy(x, pseudo_map, reduction="none", ignore_index=-1) BT,h,w = error_map.shape errors = error_map.view(-1,T,h,w) error_ref, error_t = errors[:,0], errors[:,1:] return error_ref.mean(), error_t.mean(), error_map def _forward_reg(self, frames2, norm): losses = {} if not self.cfg.TRAIN.STOP_GRAD: k2, res3, res4 = self.fast_net(frames2, norm) return k2, res3, res4, losses training = self.fast_net.training if self.cfg.TRAIN.BLOCK_BN: self.fast_net.eval() with torch.no_grad(): k2, res3, res4 = self.fast_net(frames2, norm) if self.cfg.TRAIN.BLOCK_BN: self.fast_net.train(training) return k2, res3, res4, losses def fetch_first(self, x1, x2, T): assert x1.shape[1:] == x2.shape[1:] c,h,w = x1.shape[1:] x1 = x1.view(-1,T+1,c,h,w) x2 = x2.view(-1,T-1,c,h,w) x2 = torch.cat((x1[:,-1:], x2), 1) x1 = x1[:,:-1] return x1.flatten(0,1), x2.flatten(0,1) def forward(self, frames, frames2=None, mask=None, T=None, affine=None, affine2=None, embd_only=False, norm=True, dbg=False): """Extract temporal correspondences Args: frames: [B,T,C,H,W] Returns: losses: a dictionary with the embedding loss net_outs: feature embeddings """ # embedding for self-supervised learning key1, res3, res4 = self.fast_net(frames, norm) outs, losses = {}, {} if embd_only: # only embedding return res3, res4, key1 else: key2, res3_2, res4_2, losses = self._forward_reg(frames2, norm) # fetching the first frame from the second view key1, key2 = self.fetch_first(key1, key2, T) vals, pseudo, index, dbg_info = self._cluster_grid(key1, key2, affine, affine2, T) vals_pseudo, pseudo2, grid_mask = dbg_info key1_aligned = self._align(key1, affine) key2_aligned = self._align(key2, affine2) n_ref = self.cfg.MODEL.GRID_SIZE_REF losses["cross_key"] = self._ref_loss(key1_aligned[::T], key2_aligned[::T], N = n_ref) # losses _, losses["temp"], outs["error_map"] = self._ce_loss(vals, pseudo, T) # computing the main loss losses["main"] = self.cfg.MODEL.CE_REF * losses["cross_key"] + losses["temp"] if dbg: vals = F.softmax(vals, 1) vals_pseudo = F.softmax(vals_pseudo, 1) frames, frames2 = self.fetch_first(frames, frames2, T) outs["frames_orig"] = frames outs["frames"] = self._align(frames, affine) outs["frames2"] = self._align(frames2, affine2) outs["map_soft"] = vals outs["map"] = pseudo outs["map_target_soft"] = vals_pseudo outs["map_target"] = pseudo2 outs["grid_mask"] = grid_mask() outs["aff_mask1"], outs["aff11"], outs["aff12"], \ outs["aff_mask2"], outs["aff21"], outs["aff22"] = self._aff_sample(key1, key2, T) return losses, outs ================================================ FILE: models/net.py ================================================ """ Copyright (c) 2021 TU Darmstadt Author: Nikita Araslanov License: Apache License 2.0 """ import torch import torch.nn as nn import torch.nn.functional as F from models.base import BaseNet class MLP(nn.Sequential): def __init__(self, n_in, n_out): super().__init__() self.add_module("conv1", nn.Conv2d(n_in, n_in, 1, 1)) self.add_module("bn1", nn.BatchNorm2d(n_in)) self.add_module("relu", nn.ReLU(True)) self.add_module("conv2", nn.Conv2d(n_in, n_out, 1, 1)) class Net(BaseNet): def __init__(self, cfg, backbone): super(Net, self).__init__() self.cfg = cfg self.backbone = backbone self.emb_q = MLP(backbone.fdim, cfg.MODEL.FEATURE_DIM) def lr_mult(self): """Learning rate multiplier for weights. Returns: [old, new]""" return 1., 1. 
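The pseudo-labels computed in Framework._cluster_grid are anchored at features sampled on a regular grid with a random per-clip offset (Framework._sample_index). A standalone sketch of that sampling with toy sizes, not the config values:

import torch

def sample_grid_index(B, H, W, N):
    # one index per cell of an N x N grid, jittered per clip
    # (standalone version of Framework._sample_index; illustrative only)
    xs, ys = W // N, H // N                                   # cell size
    x = torch.arange(0, W, xs).view(1, 1, N)                  # grid columns
    y = torch.arange(0, H, ys).view(1, N, 1)                  # grid rows
    x = x + torch.randint(0, xs, (B, 1, 1))                   # random offset inside each cell
    y = y + torch.randint(0, ys, (B, 1, 1))
    return (x + y * W).view(B, N * N)                         # flat indices into H*W

# e.g. 2 clips, 32x32 feature maps, a 4x4 grid of anchors per clip
print(sample_grid_index(2, 32, 32, 4).shape)                  # torch.Size([2, 16])
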
def lr_mult_bias(self): """Learning rate multiplier for bias. Returns: [old, new]""" return 2., 2. def forward(self, frames, norm=True): """Forward pass to extract projection and task features""" # extracting the time dimension res4, res3 = self.backbone(frames) # B,K,H,W query = self.emb_q(res4) if norm: query = F.normalize(query, p=2, dim=1) res3 = F.normalize(res3, p=2, dim=1) res4 = F.normalize(res4, p=2, dim=1) return query, res3, res4 ================================================ FILE: models/resnet18.py ================================================ """ Based on Jabri et al., (2020) Credit: https://github.com/ajabri/videowalk.git License: MIT """ import os import torch import torch.nn as nn import torch.nn.functional as F import torchvision.models.resnet as torch_resnet from torchvision.models.resnet import BasicBlock class ResNet(torch_resnet.ResNet): def __init__(self, *args, **kwargs): super(ResNet, self).__init__(*args, **kwargs) def filter_layers(self, x): return [l for l in x if getattr(self, l) is not None] def remove_layers(self, remove_layers=[]): # Remove extraneous layers remove_layers += ['fc', 'avgpool'] for layer in self.filter_layers(remove_layers): setattr(self, layer, None) def modify(self): # Set stride of layer3 and layer 4 to 1 (from 2) for layer in self.filter_layers(['layer3']): for m in getattr(self, layer).modules(): if isinstance(m, torch.nn.Conv2d): m.stride = (1, 1) for layer in self.filter_layers(['layer4']): for m in getattr(self, layer).modules(): if isinstance(m, torch.nn.Conv2d): m.stride = (1, 1) def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = x if self.maxpool is None else self.maxpool(x) x = self.layer1(x) x = self.layer2(x) x3 = self.layer3(x) x4 = self.layer4(x3) return x4, x3 def _resnet(arch, block, layers, pretrained, **kwargs): model = ResNet(block, layers, **kwargs) return model def resnet18(pretrained='', remove_layers=[], train=True, **kwargs): model = _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, **kwargs) model.modify() model.remove_layers(remove_layers) setattr(model, "fdim", 512) return model ================================================ FILE: opts.py ================================================ """ Copyright (c) 2021 TU Darmstadt Author: Nikita Araslanov License: Apache License 2.0 """ from __future__ import print_function import os import torch import argparse from core.config import cfg def add_global_arguments(parser): # # Model details # parser.add_argument("--snapshot-dir", type=str, default='./snapshots', help="Where to save snapshots of the model.") parser.add_argument("--logdir", type=str, default='./logs', help="Where to save log files of the model.") parser.add_argument("--exp", type=str, default="main", help="ID of the experiment (multiple runs)") parser.add_argument("--run", type=str, help="ID of the run") parser.add_argument('--workers', type=int, default=8, metavar='N', help='dataloader threads') parser.add_argument('--seed', default=64, type=int, help='seed for initializing training. 
') # # Inference only # parser.add_argument("--infer-list", default="voc12/val.txt", type=str) parser.add_argument('--mask-output-dir', type=str, default=None, help='path where to save masks') parser.add_argument("--resume", type=str, default=None, help="Snapshot \"ID,iter\" to load") # # Configuration # parser.add_argument( '--cfg', dest='cfg_file', required=True, help='Config file for training (and optionally testing)') parser.add_argument( '--set', dest='set_cfgs', help='Set config keys. Key value sequence seperate by whitespace.' 'e.g. [key] [value] [key] [value]', default=[], nargs='+') def maybe_create_dir(path): if not os.path.exists(path): os.makedirs(path) def check_global_arguments(args): args.cuda = torch.cuda.is_available() print("Available threads: ", torch.get_num_threads()) args.logdir = os.path.join(args.logdir, args.exp, args.run) maybe_create_dir(args.logdir) # # Model directories # args.snapshot_dir = os.path.join(args.snapshot_dir, args.exp, args.run) maybe_create_dir(args.snapshot_dir) def get_arguments(args_in): """Parse all the arguments provided from the CLI. Returns: A list of parsed arguments. """ parser = argparse.ArgumentParser(description="Dense Unsupervised Learning for Video Segmentation") add_global_arguments(parser) args = parser.parse_args(args_in) check_global_arguments(args) return args ================================================ FILE: requirements.txt ================================================ # This file may be used to create an environment using: # $ conda create --name --file # platform: linux-64 setproctitle matplotlib tensorboard pyyaml packaging opencv-python scikit-image ================================================ FILE: train.py ================================================ """ Copyright (c) 2021 TU Darmstadt Author: Nikita Araslanov License: Apache License 2.0 """ from __future__ import print_function import os import sys import numpy as np import time import random import setproctitle from functools import partial import torch import torch.nn.functional as F from datasets import * from models import get_model from base_trainer import BaseTrainer from opts import get_arguments from core.config import cfg, cfg_from_file, cfg_from_list from utils.timer import Timer from utils.stat_manager import StatManager from utils.davis2017 import evaluate_semi from labelprop.crw import CRW from torch.utils.tensorboard import SummaryWriter torch.backends.cudnn.benchmark = True torch.backends.cudnn.deterministic = False class Trainer(BaseTrainer): def __init__(self, args, cfg): super(Trainer, self).__init__(args, cfg) # train loader for target domain self.loader = get_dataloader(args, cfg, 'train') # alias self.denorm = self.loader.dataset.denorm # val loaders for source and target domains self.valloaders = get_dataloader(args, cfg, 'val') # writers (only main) self.writer_val = {} for val_set in self.valloaders.keys(): logdir_val = os.path.join(args.logdir, val_set) self.writer_val[val_set] = SummaryWriter(logdir_val) # model self.net = get_model(cfg, remove_layers=cfg.MODEL.REMOVE_LAYERS) print("Train Net: ") print(self.net) # optimizer using different LR net_params = self.net.parameter_groups(cfg.MODEL.LR, cfg.MODEL.WEIGHT_DECAY) print("Optimising parameter groups: ") for i, g in enumerate(net_params): print("[{}]: # parameters: {}, lr = {:4.3e}".format(i, len(g["params"]), g["lr"])) self.optim = self.get_optim(net_params, cfg.MODEL) print("# of params: ", len(list(self.net.parameters()))) # LR scheduler if cfg.MODEL.LR_SCHEDULER == 
"step": self.scheduler = torch.optim.lr_scheduler.StepLR(self.optim, \ step_size=cfg.MODEL.LR_STEP, \ gamma=cfg.MODEL.LR_GAMMA) elif cfg.MODEL.LR_SCHEDULER == "linear": # linear decay def lr_lambda(epoch): mult = 1 - epoch / (float(self.cfg.TRAIN.NUM_EPOCHS) - self.start_epoch) mult = mult ** self.cfg.MODEL.LR_POWER #print("Linear Scheduler: mult = {:4.3f}".format(mult)) return mult self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optim, lr_lambda) else: self.scheduler = None self.vis_batch = None # using cuda self.net.cuda() self.crw = CRW(cfg.TEST) # checkpoint management self.checkpoint.create_model(self.net, self.optim) if not args.resume is None: self.start_epoch, self.best_score = self.checkpoint.load(args.resume, "cuda:0") def step_seg(self, epoch, batch_src, key, temp=None, train=False, visualise=False, \ save_batch=False, writer=None, tag="train_src"): frames, masks_gt, n_obj, seq_name = batch_src # semi-supervised: select only the first frames = frames.flatten(0,1) masks_gt = masks_gt.flatten(0,1) masks_gt = masks_gt[:, :n_obj.item()] masks_ref = masks_gt.clone() masks_ref[1:] *= 0 T = frames.shape[0] fetch = {"res3": lambda x: x[0], \ "res4": lambda x: x[1], \ "key": lambda x: x[2]} # number of iterations bs = self.cfg.TRAIN.BATCH_SIZE feats = [] t0 = time.time() torch.cuda.empty_cache() for t in range(0, T, bs): # next frame frames_batch = frames[t:t+bs].cuda() # source forward pass feats_ = self.net(frames_batch, embd_only=True) feats.append(fetch[key](feats_).cpu()) feats = torch.cat(feats, 0) print("Inference: {:4.3f}s".format(time.time() - t0)) sys.stdout.flush() t0 = time.time() outs = self.crw.forward(feats, masks_ref) print("CRW propagation: {:4.3f}s".format(time.time() - t0)) sys.stdout.flush() outs["masks_gt"] = masks_gt.argmax(1) if visualise: outs["frames"] = frames self._visualise_seg(epoch, outs, writer, tag) if save_batch: # Saving batch for visualisation # saving the batch for visualisation self.save_vis_batch(tag, batch_src) return outs def step(self, epoch, batch_in, train=False, visualise=False, save_batch=False, writer=None, tag="train"): frames1, frames2, affine1, affine2 = batch_in assert frames1.size() == frames2.size(), "Frames shape mismatch" # We could simply do # images1 = frames1.flatten(0,1).cuda() # images2 = frames2.flatten(0,1).cuda() # Instead we pull the reference frame from the 2nd view # to the first view so that the regularising branch is # always in evaluation mode to save the GPU memory B,T,C,H,W = frames1.shape images1 = torch.cat((frames1, frames2[:, ::T]), 1) images1 = images1.flatten(0,1).cuda() images2 = frames2[:, 1:].flatten(0,1).cuda() affine1 = affine1.flatten(0,1).cuda() affine2 = affine2.flatten(0,1).cuda() # source forward pass losses, outs = self.net(images1, frames2=images2, T=T, \ affine=affine1, affine2=affine2, \ dbg=visualise) if train: self.optim.zero_grad() losses["main"].backward() self.optim.step() if visualise: self._visualise(epoch, outs, T, writer, tag) if save_batch: # Saving batch for visualisation self.save_vis_batch(tag, batch_in) # summarising the losses # into python scalars losses_ret = {} for key, val in losses.items(): losses_ret[key] = val.mean().item() return losses_ret, outs def train_epoch(self, epoch): stat = StatManager() # adding stats for classes timer = Timer("Epoch {}".format(epoch)) step = partial(self.step, train=True, visualise=False) # training mode self.net.train() for i, batch in enumerate(self.loader): save_batch = i == 0 # # Forward pass # losses, _ = step(epoch, batch, 
save_batch=save_batch, tag="train") for loss_key, loss_val in losses.items(): stat.update_stats(loss_key, loss_val) # intermediate logging if i % 10 == 0: msg = "Loss [{:04d}]: ".format(i) for loss_key, loss_val in losses.items(): msg += " {} {:.4f} | ".format(loss_key, loss_val) msg += " | Im/Sec: {:.1f}".format(i * self.cfg.TRAIN.BATCH_SIZE / timer.get_stage_elapsed()) print(msg) sys.stdout.flush() for name, val in stat.items(): print("{}: {:4.3f}".format(name, val)) self.writer.add_scalar('all/{}'.format(name), val, epoch) # plotting learning rate for ii, l in enumerate(self.optim.param_groups): print("Learning rate [{}]: {:4.3e}".format(ii, l['lr'])) self.writer.add_scalar('lr/enc_group_%02d' % ii, l['lr'], epoch) # plotting moment distance if stat.has_vals("lr_gamma"): self.writer.add_scalar('hyper/gamma', stat.summarize_key("lr_gamma"), epoch) if epoch % self.cfg.LOG.ITER_TRAIN == 0: self.visualise_results(epoch, self.writer, "train", self.step) def validation(self, epoch, writer, loader, tag=None, max_iter=None): stat = StatManager() if max_iter is None: max_iter = len(loader) # Fast test during the training def eval_batch(batch): loss, masks = self.step(epoch, batch, train=False, visualise=False) for loss_key, loss_val in loss.items(): stat.update_stats(loss_key, loss_val) return masks self.net.eval() print("Starting validation") sys.stdout.flush() for n, batch in enumerate(loader): with torch.no_grad(): # note video length assumed 1 eval_batch(batch) if not tag is None and not self.has_vis_batch(tag): self.save_vis_batch(tag, batch) checkpoint_score = 0.0 # total classification loss for stat_key, stat_val in stat.items(): writer.add_scalar('all/{}'.format(stat_key), stat_val, epoch) print('Loss {}: {:4.3f}'.format(stat_key, stat_val)) if not tag is None and epoch % self.cfg.LOG.ITER_TRAIN == 0: self.visualise_results(epoch, writer, tag, self.step) return checkpoint_score def validation_seg(self, epoch, writer, loader, key="all", temp=None, tag=None, max_iter=None): vis = key == "res4" stat = StatManager() if max_iter is None: max_iter = len(loader) if temp is None: temp = self.cfg.TEST.TEMP step_fn = partial(self.step_seg, key=key, temp=temp, train=False, visualise=vis, writer=writer) # Fast test during the training def eval_batch(n, batch): tag_n = tag + "_{:02d}".format(n) masks = step_fn(epoch, batch, tag=tag_n) return masks self.net.eval() def davis_mask(masks): masks = masks.cpu() num_objects = int(masks.max()) tmp = torch.ones(num_objects, *masks.shape) tmp = tmp * torch.arange(1, num_objects + 1)[:, None, None, None] return (tmp == masks[None, ...]).long().numpy() Js = {"M": [], "R": [], "D": []} Fs = {"M": [], "R": [], "D": []} timer = Timer("[Epoch {}] Validation-Seg".format(epoch)) tag_key = "{}_{}_{:3.2f}".format(tag, key, temp) for n, batch in enumerate(loader): seq_name = batch[-1][0] print("Sequence: ", seq_name) sys.stdout.flush() with torch.no_grad(): masks_out = eval_batch(n, batch) # second element is assumed to be always GT masks masks_gt = davis_mask(masks_out["masks_gt"]) masks_pred = davis_mask(masks_out["masks_pred_idx"]) assert masks_gt.shape == masks_pred.shape # converting to a digestible format # [num_objects, seq_length, height, width] if not tag_key is None and not self.has_vis_batch(tag_key): self.save_vis_batch(tag_key, batch) start_t = time.time() metrics_res = evaluate_semi((masks_gt, ), (masks_pred, )) J, F = metrics_res['J'], metrics_res['F'] print("Evaluation: {:4.3f}s".format(time.time() - start_t)) print("Jaccard: ", J["M"]) print("F-Score: ", 
F["M"]) for l in ("M", "R", "D"): Js[l] += J[l] Fs[l] += F[l] msg = "{} | Im/Sec: {:.1f}".format(n, n * batch[0].shape[1] / timer.get_stage_elapsed()) print(msg) sys.stdout.flush() g_measures = ['J&F-Mean', 'J-Mean', 'J-Recall', 'J-Decay', 'F-Mean', 'F-Recall', 'F-Decay'] # Generate dataframe for the general results final_mean = (np.mean(Js["M"]) + np.mean(Fs["M"])) / 2. g_res = [final_mean, \ np.mean(Js["M"]), np.mean(Js["R"]), np.mean(Js["D"]), \ np.mean(Fs["M"]), np.mean(Fs["R"]), np.mean(Fs["D"])] for (name, val) in zip(g_measures, g_res): writer.add_scalar('{}_{:3.2f}/{}'.format(key, temp, name), val, epoch) print('{}: {:4.3f}'.format(name, val)) return final_mean def train(args, cfg): setproctitle.setproctitle("dense-ulearn | {}".format(args.run)) if args.seed is not None: print("Setting the seed: {}".format(args.seed)) random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) torch.cuda.manual_seed(args.seed) torch.cuda.manual_seed_all(args.seed) trainer = Trainer(args, cfg) timer = Timer() def time_call(func, msg, *args, **kwargs): timer.reset_stage() val = func(*args, **kwargs) print(msg + (" {:3.2}m".format(timer.get_stage_elapsed() / 60.))) return val for epoch in range(trainer.start_epoch, cfg.TRAIN.NUM_EPOCHS + 1): # training 1 epoch time_call(trainer.train_epoch, "Train epoch: ", epoch) print("Epoch >>> {:02d} <<< ".format(epoch)) if epoch % cfg.LOG.ITER_VAL == 0: for val_set in ("val_video", ): time_call(trainer.validation, "Validation / {} / Val: ".format(val_set), \ epoch, trainer.writer_val[val_set], trainer.valloaders[val_set], tag=val_set) best_layer = None best_score = -1e10 for val_set in ("val_video_seg", ): writer = trainer.writer_val[val_set] loader = trainer.valloaders[val_set] for layer in ("key", "res4"): msg = ">>> Validation {} / {} <<<".format(layer, val_set) score = time_call(trainer.validation_seg, msg, epoch, writer, loader, key=layer, tag=val_set) if score > best_score: best_score = score best_layer = layer print("Best score / layer: {:4.2f} / {}".format(best_score, best_layer)) if val_set =="val_video_seg": trainer.checkpoint_best(best_score, epoch, best_layer) if not trainer.scheduler is None and cfg.MODEL.LR_SCHED_USE_EPOCH: trainer.scheduler.step() def main(): args = get_arguments(sys.argv[1:]) # Reading the config cfg_from_file(args.cfg_file) if args.set_cfgs is not None: cfg_from_list(args.set_cfgs) train(args, cfg) if __name__ == "__main__": main() ================================================ FILE: utils/checkpoints.py ================================================ """ Copyright (c) 2021 TU Darmstadt Author: Nikita Araslanov License: Apache License 2.0 """ import os import sys import torch class Checkpoint(object): def __init__(self, path, max_n=3): self.path = path self.max_n = max_n self.models = {} self.checkpoints = [] def create_model(self, model, opt): self.models = {} self.models['model'] = model self.models['opt'] = opt def limit(self): return self.max_n def __len__(self): return len(self.checkpoints) def _get_full_path(self, suffix): filename = self._filename(suffix) return os.path.join(self.path, filename) def clean(self): n_remove = max(0, len(self.checkpoints) - self.max_n) for i in range(n_remove): self._rm(self.checkpoints[i]) self.checkpoints = self.checkpoints[n_remove:] def _rm(self, suffix): path = self._get_full_path(suffix) if os.path.isfile(path): os.remove(path) def _filename(self, suffix): return "{}.pth".format(suffix) def load(self, path, location): if len(path) > 0 and not os.path.isfile(path): 
print("Snapshot {} not found".format(path)) data = torch.load(path, map_location=location) data_mapped = {} for key, val in data["model"].items(): data_mapped[key.replace("module.", "")] = val self.models["model"].load_state_dict(data_mapped, strict=True) return data["epoch"], data["score"] def checkpoint(self, score, epoch, t): suffix = "epoch{:03d}_score{:4.3f}_{}".format(epoch, score, t) self.checkpoints.append(suffix) path = self._get_full_path(suffix) if not os.path.isfile(path): torch.save({"model": self.models["model"].state_dict(), "opt": self.models["opt"].state_dict(), "score": score, "epoch": epoch}, path) # removing if more than allowed number of snapshots self.clean() ================================================ FILE: utils/collections.py ================================================ # Copyright (c) 2017-present, Facebook, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ############################################################################## """A simple attribute dictionary used for representing configuration options.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from __future__ import unicode_literals class AttrDict(dict): IMMUTABLE = '__immutable__' def __init__(self, *args, **kwargs): super(AttrDict, self).__init__(*args, **kwargs) self.__dict__[AttrDict.IMMUTABLE] = False def __getattr__(self, name): if name in self.__dict__: return self.__dict__[name] elif name in self: return self[name] else: raise AttributeError(name) def __setattr__(self, name, value): if not self.__dict__[AttrDict.IMMUTABLE]: if name in self.__dict__: self.__dict__[name] = value else: self[name] = value else: raise AttributeError( 'Attempted to set "{}" to "{}", but AttrDict is immutable'. format(name, value) ) def immutable(self, is_immutable): """Set immutability to is_immutable and recursively apply the setting to all nested AttrDicts. 
""" self.__dict__[AttrDict.IMMUTABLE] = is_immutable # Recursively set immutable state for v in self.__dict__.values(): if isinstance(v, AttrDict): v.immutable(is_immutable) for v in self.values(): if isinstance(v, AttrDict): v.immutable(is_immutable) def is_immutable(self): return self.__dict__[AttrDict.IMMUTABLE] ================================================ FILE: utils/davis2017.py ================================================ """ Credit: https://github.com/davisvideochallenge/davis2017-evaluation.git License: BSD 3-Clause Copyright (c) 2020, DAVIS: Densely Annotated VIdeo Segmentation """ import sys import numpy as np from utils.davis2017_metrics import db_eval_boundary, db_eval_iou import utils.davis2017_utils as utils def _evaluate_semisupervised(all_gt_masks, all_res_masks, all_void_masks, metric): if all_res_masks.shape[0] > all_gt_masks.shape[0]: sys.stdout.write("\nIn your PNG files there is an index higher than the number of objects in the sequence!") sys.exit() elif all_res_masks.shape[0] < all_gt_masks.shape[0]: sys.stdout.write("\nThe number of predictions is less than ground truth. Padding with zero.") zero_padding = np.zeros((all_gt_masks.shape[0] - all_res_masks.shape[0], *all_res_masks.shape[1:])) all_res_masks = np.concatenate([all_res_masks, zero_padding], axis=0) j_metrics_res, f_metrics_res = np.zeros(all_gt_masks.shape[:2]), np.zeros(all_gt_masks.shape[:2]) for ii in range(all_gt_masks.shape[0]): if 'J' in metric: j_metrics_res[ii, :] = db_eval_iou(all_gt_masks[ii, ...], all_res_masks[ii, ...], all_void_masks) sys.stdout.flush() if 'F' in metric: f_metrics_res[ii, :] = db_eval_boundary(all_gt_masks[ii, ...], all_res_masks[ii, ...], all_void_masks) sys.stdout.flush() return j_metrics_res, f_metrics_res def evaluate_semi(all_gt_masks, all_res_masks, metric=('J', 'F'), debug=False): metric = metric if isinstance(metric, tuple) or isinstance(metric, list) else [metric] if 'T' in metric: raise ValueError('Temporal metric not supported!') if 'J' not in metric and 'F' not in metric: raise ValueError('Metric possible values are J for IoU or F for Boundary') # Containers metrics_res = {} if 'J' in metric: metrics_res['J'] = {"M": [], "R": [], "D": [], "M_per_object": {}} if 'F' in metric: metrics_res['F'] = {"M": [], "R": [], "D": [], "M_per_object": {}} for seq, (seq_gt_masks, seq_res_masks) in enumerate(zip(all_gt_masks, all_res_masks)): seq_gt_masks, seq_res_masks = seq_gt_masks[:, 1:-1, :, :], seq_res_masks[:, 1:-1, :, :] j_metrics_res, f_metrics_res = _evaluate_semisupervised(seq_gt_masks, seq_res_masks, None, metric) for ii in range(seq_gt_masks.shape[0]): seq_name = f'{seq}_{ii+1}' if 'J' in metric: [JM, JR, JD] = utils.db_statistics(j_metrics_res[ii]) metrics_res['J']["M"].append(JM) metrics_res['J']["R"].append(JR) metrics_res['J']["D"].append(JD) metrics_res['J']["M_per_object"][seq_name] = JM if 'F' in metric: [FM, FR, FD] = utils.db_statistics(f_metrics_res[ii]) metrics_res['F']["M"].append(FM) metrics_res['F']["R"].append(FR) metrics_res['F']["D"].append(FD) metrics_res['F']["M_per_object"][seq_name] = FM # Show progress if debug: sys.stdout.write(seq + '\n') sys.stdout.flush() return metrics_res ================================================ FILE: utils/davis2017_metrics.py ================================================ """ Credit: https://github.com/davisvideochallenge/davis2017-evaluation.git License: BSD 3-Clause Copyright (c) 2020, DAVIS: Densely Annotated VIdeo Segmentation """ import math import sys import numpy as np import cv2 
cv2.setNumThreads(0) from skimage.morphology import disk def db_eval_iou(annotation, segmentation, void_pixels=None): """ Compute region similarity as the Jaccard Index. Arguments: annotation (ndarray): binary annotation map. segmentation (ndarray): binary segmentation map. void_pixels (ndarray): optional mask with void pixels Return: jaccard (float): region similarity """ assert annotation.shape == segmentation.shape, \ f'Annotation({annotation.shape}) and segmentation:{segmentation.shape} dimensions do not match.' annotation = annotation.astype(np.bool) segmentation = segmentation.astype(np.bool) if void_pixels is not None: assert annotation.shape == void_pixels.shape, \ f'Annotation({annotation.shape}) and void pixels:{void_pixels.shape} dimensions do not match.' void_pixels = void_pixels.astype(np.bool) else: void_pixels = np.zeros_like(segmentation) # Intersection between all sets inters = np.sum((segmentation & annotation) & np.logical_not(void_pixels), axis=(-2, -1)) union = np.sum((segmentation | annotation) & np.logical_not(void_pixels), axis=(-2, -1)) j = inters / union if j.ndim == 0: j = 1 if np.isclose(union, 0) else j else: j[np.isclose(union, 0)] = 1 return j def db_eval_boundary(annotation, segmentation, void_pixels=None, bound_th=0.008): assert annotation.shape == segmentation.shape if void_pixels is not None: assert annotation.shape == void_pixels.shape if annotation.ndim == 3: n_frames = annotation.shape[0] f_res = np.zeros(n_frames) for frame_id in range(n_frames): void_pixels_frame = None if void_pixels is None else void_pixels[frame_id, :, :, ] f_res[frame_id] = f_measure(segmentation[frame_id, :, :, ], annotation[frame_id, :, :], void_pixels_frame, bound_th=bound_th) elif annotation.ndim == 2: f_res = f_measure(segmentation, annotation, void_pixels, bound_th=bound_th) else: raise ValueError(f'db_eval_boundary does not support tensors with {annotation.ndim} dimensions') return f_res def f_measure(foreground_mask, gt_mask, void_pixels=None, bound_th=0.008): """ Compute mean,recall and decay from per-frame evaluation. Calculates precision/recall for boundaries between foreground_mask and gt_mask using morphological operators to speed it up. Arguments: foreground_mask (ndarray): binary segmentation image. gt_mask (ndarray): binary annotated image. 
void_pixels (ndarray): optional mask with void pixels Returns: F (float): boundaries F-measure """ assert np.atleast_3d(foreground_mask).shape[2] == 1 if void_pixels is not None: void_pixels = void_pixels.astype(np.bool) else: void_pixels = np.zeros_like(foreground_mask).astype(np.bool) bound_pix = bound_th if bound_th >= 1 else \ np.ceil(bound_th * np.linalg.norm(foreground_mask.shape)) # Get the pixel boundaries of both masks fg_boundary = _seg2bmap(foreground_mask * np.logical_not(void_pixels)) gt_boundary = _seg2bmap(gt_mask * np.logical_not(void_pixels)) # fg_dil = binary_dilation(fg_boundary, disk(bound_pix)) fg_dil = cv2.dilate(fg_boundary.astype(np.uint8), disk(bound_pix).astype(np.uint8)) # gt_dil = binary_dilation(gt_boundary, disk(bound_pix)) gt_dil = cv2.dilate(gt_boundary.astype(np.uint8), disk(bound_pix).astype(np.uint8)) # Get the intersection gt_match = gt_boundary * fg_dil fg_match = fg_boundary * gt_dil # Area of the intersection n_fg = np.sum(fg_boundary) n_gt = np.sum(gt_boundary) # % Compute precision and recall if n_fg == 0 and n_gt > 0: precision = 1 recall = 0 elif n_fg > 0 and n_gt == 0: precision = 0 recall = 1 elif n_fg == 0 and n_gt == 0: precision = 1 recall = 1 else: precision = np.sum(fg_match) / float(n_fg) recall = np.sum(gt_match) / float(n_gt) # Compute F measure if precision + recall == 0: F = 0 else: F = 2 * precision * recall / (precision + recall) return F def _seg2bmap(seg, width=None, height=None): """ From a segmentation, compute a binary boundary map with 1 pixel wide boundaries. The boundary pixels are offset by 1/2 pixel towards the origin from the actual segment boundary. Arguments: seg : Segments labeled from 1..k. width : Width of desired bmap <= seg.shape[1] height : Height of desired bmap <= seg.shape[0] Returns: bmap (ndarray): Binary boundary map. David Martin January 2003 """ seg = seg.astype(np.bool) seg[seg > 0] = 1 assert np.atleast_3d(seg).shape[2] == 1 width = seg.shape[1] if width is None else width height = seg.shape[0] if height is None else height h, w = seg.shape[:2] ar1 = float(width) / float(height) ar2 = float(w) / float(h) assert not ( width > w | height > h | abs(ar1 - ar2) > 0.01 ), "Can" "t convert %dx%d seg to %dx%d bmap." 
% (w, h, width, height) e = np.zeros_like(seg) s = np.zeros_like(seg) se = np.zeros_like(seg) e[:, :-1] = seg[:, 1:] s[:-1, :] = seg[1:, :] se[:-1, :-1] = seg[1:, 1:] b = seg ^ e | seg ^ s | seg ^ se b[-1, :] = seg[-1, :] ^ e[-1, :] b[:, -1] = seg[:, -1] ^ s[:, -1] b[-1, -1] = 0 if w == width and h == height: bmap = b else: bmap = np.zeros((height, width)) for x in range(w): for y in range(h): if b[y, x]: j = 1 + math.floor((y - 1) + height / h) i = 1 + math.floor((x - 1) + width / h) bmap[j, i] = 1 return bmap if __name__ == '__main__': from davis2017.davis import DAVIS from davis2017.results import Results dataset = DAVIS(root='input_dir/ref', subset='val', sequences='aerobatics') results = Results(root_dir='examples/osvos') # Test timing F measure for seq in dataset.get_sequences(): all_gt_masks, _, all_masks_id = dataset.get_all_masks(seq, True) all_gt_masks, all_masks_id = all_gt_masks[:, 1:-1, :, :], all_masks_id[1:-1] all_res_masks = results.read_masks(seq, all_masks_id) f_metrics_res = np.zeros(all_gt_masks.shape[:2]) for ii in range(all_gt_masks.shape[0]): f_metrics_res[ii, :] = db_eval_boundary(all_gt_masks[ii, ...], all_res_masks[ii, ...]) # Run using to profile code: python -m cProfile -o f_measure.prof metrics.py # snakeviz f_measure.prof ================================================ FILE: utils/davis2017_utils.py ================================================ """ Credit: https://github.com/davisvideochallenge/davis2017-evaluation.git License: BSD 3-Clause Copyright (c) 2020, DAVIS: Densely Annotated VIdeo Segmentation """ import os import errno import numpy as np from PIL import Image import warnings def _pascal_color_map(N=256, normalized=False): """ Python implementation of the color map function for the PASCAL VOC data set. Official Matlab version can be found in the PASCAL VOC devkit http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html#devkit """ def bitget(byteval, idx): return (byteval & (1 << idx)) != 0 dtype = 'float32' if normalized else 'uint8' cmap = np.zeros((N, 3), dtype=dtype) for i in range(N): r = g = b = 0 c = i for j in range(8): r = r | (bitget(c, 0) << 7 - j) g = g | (bitget(c, 1) << 7 - j) b = b | (bitget(c, 2) << 7 - j) c = c >> 3 cmap[i] = np.array([r, g, b]) cmap = cmap / 255 if normalized else cmap return cmap def overlay_semantic_mask(im, ann, alpha=0.5, colors=None, contour_thickness=None): im, ann = np.asarray(im, dtype=np.uint8), np.asarray(ann, dtype=np.int) if im.shape[:-1] != ann.shape: raise ValueError('First two dimensions of `im` and `ann` must match') if im.shape[-1] != 3: raise ValueError('im must have three channels at the 3 dimension') colors = colors or _pascal_color_map() colors = np.asarray(colors, dtype=np.uint8) mask = colors[ann] fg = im * alpha + (1 - alpha) * mask img = im.copy() img[ann > 0] = fg[ann > 0] if contour_thickness: # pragma: no cover import cv2 for obj_id in np.unique(ann[ann > 0]): contours = cv2.findContours((ann == obj_id).astype( np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2:] cv2.drawContours(img, contours[0], -1, colors[obj_id].tolist(), contour_thickness) return img def generate_obj_proposals(davis_root, subset, num_proposals, save_path): dataset = DAVIS(davis_root, subset=subset, codalab=True) for seq in dataset.get_sequences(): save_dir = os.path.join(save_path, seq) if os.path.exists(save_dir): continue all_gt_masks, all_masks_id = dataset.get_all_masks(seq, True) img_size = all_gt_masks.shape[2:] num_rows = int(np.ceil(np.sqrt(num_proposals))) proposals = np.zeros((num_proposals, 
len(all_masks_id), *img_size)) height_slices = np.floor(np.arange(0, img_size[0] + 1, img_size[0]/num_rows)).astype(np.uint).tolist() width_slices = np.floor(np.arange(0, img_size[1] + 1, img_size[1]/num_rows)).astype(np.uint).tolist() ii = 0 prev_h, prev_w = 0, 0 for h in height_slices[1:]: for w in width_slices[1:]: proposals[ii, :, prev_h:h, prev_w:w] = 1 prev_w = w ii += 1 if ii == num_proposals: break prev_h, prev_w = h, 0 if ii == num_proposals: break os.makedirs(save_dir, exist_ok=True) for i, mask_id in enumerate(all_masks_id): mask = np.sum(proposals[:, i, ...] * np.arange(1, proposals.shape[0] + 1)[:, None, None], axis=0) save_mask(mask, os.path.join(save_dir, f'{mask_id}.png')) def generate_random_permutation_gt_obj_proposals(davis_root, subset, save_path): dataset = DAVIS(davis_root, subset=subset, codalab=True) for seq in dataset.get_sequences(): gt_masks, all_masks_id = dataset.get_all_masks(seq, True) obj_swap = np.random.permutation(np.arange(gt_masks.shape[0])) gt_masks = gt_masks[obj_swap, ...] save_dir = os.path.join(save_path, seq) os.makedirs(save_dir, exist_ok=True) for i, mask_id in enumerate(all_masks_id): mask = np.sum(gt_masks[:, i, ...] * np.arange(1, gt_masks.shape[0] + 1)[:, None, None], axis=0) save_mask(mask, os.path.join(save_dir, f'{mask_id}.png')) def color_map(N=256, normalized=False): def bitget(byteval, idx): return ((byteval & (1 << idx)) != 0) dtype = 'float32' if normalized else 'uint8' cmap = np.zeros((N, 3), dtype=dtype) for i in range(N): r = g = b = 0 c = i for j in range(8): r = r | (bitget(c, 0) << 7-j) g = g | (bitget(c, 1) << 7-j) b = b | (bitget(c, 2) << 7-j) c = c >> 3 cmap[i] = np.array([r, g, b]) cmap = cmap/255 if normalized else cmap return cmap def save_mask(mask, img_path): if np.max(mask) > 255: raise ValueError('Maximum id pixel value is 255') mask_img = Image.fromarray(mask.astype(np.uint8)) mask_img.putpalette(color_map().flatten().tolist()) mask_img.save(img_path) def db_statistics(per_frame_values): """ Compute mean,recall and decay from per-frame evaluation. Arguments: per_frame_values (ndarray): per-frame evaluation Returns: M,O,D (float,float,float): return evaluation statistics: mean,recall,decay. 
""" # strip off nan values with warnings.catch_warnings(): warnings.simplefilter("ignore", category=RuntimeWarning) M = np.nanmean(per_frame_values) O = np.nanmean(per_frame_values > 0.5) N_bins = 4 ids = np.round(np.linspace(1, len(per_frame_values), N_bins + 1) + 1e-10) - 1 ids = ids.astype(np.uint8) D_bins = [per_frame_values[ids[i]:ids[i + 1] + 1] for i in range(0, 4)] with warnings.catch_warnings(): warnings.simplefilter("ignore", category=RuntimeWarning) D = np.nanmean(D_bins[0]) - np.nanmean(D_bins[3]) return M, O, D def list_files(dir, extension=".png"): return [os.path.splitext(file_)[0] for file_ in os.listdir(dir) if file_.endswith(extension)] def force_symlink(file1, file2): try: os.symlink(file1, file2) except OSError as e: if e.errno == errno.EEXIST: os.remove(file2) os.symlink(file1, file2) ================================================ FILE: utils/palette.py ================================================ """ Copyright (c) 2021 TU Darmstadt Author: Nikita Araslanov License: Apache License 2.0 """ import matplotlib.cm as cm import numpy as np from PIL import ImagePalette def colormap(N=256): def bitget(byteval, idx): return ((byteval & (1 << idx)) != 0) dtype = 'uint8' cmap = [] for i in range(N): r = g = b = 0 c = i for j in range(8): r = r | (bitget(c, 0) << 7-j) g = g | (bitget(c, 1) << 7-j) b = b | (bitget(c, 2) << 7-j) c = c >> 3 cmap.append((r, g, b)) return cmap def apply_cmap(masks_pred, cmap): canvas = np.zeros((masks_pred.shape[0], masks_pred.shape[1], 3)) for label in np.unique(masks_pred): canvas[masks_pred == label, :] = cmap[label] return canvas #np.transpose(canvas, [2,0,1]) def create_palette(colormap, num): cmap = cm.get_cmap(colormap) palette = ImagePalette.ImagePalette() for n in range(num): val = n / num rgb = [int(255*x) for x in cmap(val)[:-1]] palette.getcolor(tuple(rgb)) return palette def custom_palette(nclasses, cname="rainbow"): cmap = cm.get_cmap(cname, nclasses) return cmap ================================================ FILE: utils/palette_davis.py ================================================ palette_str = '''0 0 0 128 0 0 0 128 0 128 128 0 0 0 128 128 0 128 0 128 128 128 128 128 64 0 0 191 0 0 64 128 0 191 128 0 64 0 128 191 0 128 64 128 128 191 128 128 0 64 0 128 64 0 0 191 0 128 191 0 0 64 128 128 64 128 22 22 22 23 23 23 24 24 24 25 25 25 26 26 26 27 27 27 28 28 28 29 29 29 30 30 30 31 31 31 32 32 32 33 33 33 34 34 34 35 35 35 36 36 36 37 37 37 38 38 38 39 39 39 40 40 40 41 41 41 42 42 42 43 43 43 44 44 44 45 45 45 46 46 46 47 47 47 48 48 48 49 49 49 50 50 50 51 51 51 52 52 52 53 53 53 54 54 54 55 55 55 56 56 56 57 57 57 58 58 58 59 59 59 60 60 60 61 61 61 62 62 62 63 63 63 64 64 64 65 65 65 66 66 66 67 67 67 68 68 68 69 69 69 70 70 70 71 71 71 72 72 72 73 73 73 74 74 74 75 75 75 76 76 76 77 77 77 78 78 78 79 79 79 80 80 80 81 81 81 82 82 82 83 83 83 84 84 84 85 85 85 86 86 86 87 87 87 88 88 88 89 89 89 90 90 90 91 91 91 92 92 92 93 93 93 94 94 94 95 95 95 96 96 96 97 97 97 98 98 98 99 99 99 100 100 100 101 101 101 102 102 102 103 103 103 104 104 104 105 105 105 106 106 106 107 107 107 108 108 108 109 109 109 110 110 110 111 111 111 112 112 112 113 113 113 114 114 114 115 115 115 116 116 116 117 117 117 118 118 118 119 119 119 120 120 120 121 121 121 122 122 122 123 123 123 124 124 124 125 125 125 126 126 126 127 127 127 128 128 128 129 129 129 130 130 130 131 131 131 132 132 132 133 133 133 134 134 134 135 135 135 136 136 136 137 137 137 138 138 138 139 139 139 140 140 140 141 141 141 142 142 142 143 143 143 144 144 144 145 145 145 
146 146 146 147 147 147 148 148 148 149 149 149 150 150 150 151 151 151 152 152 152 153 153 153 154 154 154 155 155 155 156 156 156 157 157 157 158 158 158 159 159 159 160 160 160 161 161 161 162 162 162 163 163 163 164 164 164 165 165 165 166 166 166 167 167 167 168 168 168 169 169 169 170 170 170 171 171 171 172 172 172 173 173 173 174 174 174 175 175 175 176 176 176 177 177 177 178 178 178 179 179 179 180 180 180 181 181 181 182 182 182 183 183 183 184 184 184 185 185 185 186 186 186 187 187 187 188 188 188 189 189 189 190 190 190 191 191 191 192 192 192 193 193 193 194 194 194 195 195 195 196 196 196 197 197 197 198 198 198 199 199 199 200 200 200 201 201 201 202 202 202 203 203 203 204 204 204 205 205 205 206 206 206 207 207 207 208 208 208 209 209 209 210 210 210 211 211 211 212 212 212 213 213 213 214 214 214 215 215 215 216 216 216 217 217 217 218 218 218 219 219 219 220 220 220 221 221 221 222 222 222 223 223 223 224 224 224 225 225 225 226 226 226 227 227 227 228 228 228 229 229 229 230 230 230 231 231 231 232 232 232 233 233 233 234 234 234 235 235 235 236 236 236 237 237 237 238 238 238 239 239 239 240 240 240 241 241 241 242 242 242 243 243 243 244 244 244 245 245 245 246 246 246 247 247 247 248 248 248 249 249 249 250 250 250 251 251 251 252 252 252 253 253 253 254 254 254 255 255 255''' import numpy as np tensor = np.array([[float(x)/255 for x in line.split()] for line in palette_str.split('\n')]) from matplotlib.colors import ListedColormap palette = ListedColormap(tensor, 'davis') ================================================ FILE: utils/stat_manager.py ================================================ """ Copyright (c) 2021 TU Darmstadt Author: Nikita Araslanov License: Apache License 2.0 """ class StatManager(object): def __init__(self): self.func_keys = {} self.vals = {} self.vals_count = {} self.formats = {} def items(self): curr_vals = {} for k in self.vals: if self.has_vals(k): curr_vals[k] = self.summarize_key(k) return curr_vals.items() def reset(self): for k in self.vals: self.vals[k] = 0.0 self.vals_count[k] = 0.0 def add_val(self, key, form="{:4.3f}"): self.vals[key] = 0.0 self.vals_count[key] = 0.0 self.formats[key] = form def get_val(self, key): return self.vals[key], self.vals_count[key] def add_compute(self, key, func, form="{:4.3f}"): self.func_keys[key] = func self.add_val(key) self.formats[key] = form def update_stats(self, key, val, count = 1): if not key in self.vals: self.add_val(key) self.vals[key] += val self.vals_count[key] += count def compute_stats(self, a, b, size = 1): for k, func in self.func_keys.iteritems(): self.vals[k] += func(a, b) self.vals_count[k] += size def has_vals(self, k): return k in self.vals_count and \ self.vals_count[k] > 0 def summarize_key(self, k): if self.has_vals(k) and abs(self.vals[k]) > 0.: return self.vals[k] / self.vals_count[k] else: return 0 def summarize(self, epoch = 0, verbose = True): if verbose: out = "\tEpoch[{:03d}]".format(epoch) for k in self.vals: if self.has_vals(k): out += (" / {} " + self.formats[k]).format(k, self.summarize_key(k)) print(out) ================================================ FILE: utils/sys_tools.py ================================================ """ Copyright (c) 2021 TU Darmstadt Author: Nikita Araslanov License: Apache License 2.0 """ import os def check_dir(base_path, name): """Make sure the directory exists""" # create the directory fullpath = os.path.join(base_path, name) if not os.path.exists(fullpath): os.makedirs(fullpath) return fullpath 
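The two helpers above are small but used throughout training. The sketch below is an editor's addition (the loss names and paths are hypothetical): it shows how StatManager accumulates per-batch losses and reports running means, mirroring the logging loop in the trainer, and how check_dir resolves an output directory.

from utils.stat_manager import StatManager
from utils.sys_tools import check_dir

stat = StatManager()

# keys are registered on first use; update_stats adds the value with a count of 1
for step in range(100):
    stat.update_stats("loss_total", 0.5 / (step + 1))
    stat.update_stats("loss_temp", 0.1)

# items() yields running means (sum / count) for every key that has values
for name, val in stat.items():
    print("{}: {:4.3f}".format(name, val))

stat.summarize(epoch=1)   # one-line "Epoch[001] / loss_total ... / loss_temp ..." summary

# check_dir creates <base>/<name> if missing and returns the full path
log_dir = check_dir("./logs", "demo_run")
print("Writing logs to", log_dir)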
================================================
FILE: utils/timer.py
================================================
"""
Copyright (c) 2021 TU Darmstadt
Author: Nikita Araslanov
License: Apache License 2.0
"""

import time


class Timer:

    def __init__(self, starting_msg = None):
        self.start = time.time()
        self.stage_start = self.start

        if starting_msg is not None:
            print(starting_msg, time.ctime(time.time()))

    def update_progress(self, progress):
        self.elapsed = time.time() - self.start
        self.est_total = self.elapsed / progress
        self.est_remaining = self.est_total - self.elapsed
        self.est_finish = int(self.start + self.est_total)

    def stage(self, msg=None):
        t = self.get_stage_elapsed()
        self.reset_stage()

        if not msg is None:
            print("{:4.3f} elapsed: {}".format(t, msg))

        return t

    def str_est_finish(self):
        return str(time.ctime(self.est_finish))

    def get_stage_elapsed(self):
        return time.time() - self.stage_start

    def reset_stage(self):
        self.stage_start = time.time()
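To close, a minimal sketch of how Timer is typically driven (an editor's addition; the sleep merely stands in for an epoch of work), following the reset_stage / get_stage_elapsed pattern used by the training script:

import time

from utils.timer import Timer

timer = Timer("demo started")   # the optional message is printed with the current time

num_epochs = 3
for epoch in range(1, num_epochs + 1):
    timer.reset_stage()
    time.sleep(0.1)             # stand-in for one epoch of training
    print("Epoch {:02d} took {:4.3f}s".format(epoch, timer.get_stage_elapsed()))

    # rough ETA based on the fraction of epochs completed
    timer.update_progress(epoch / num_epochs)
    print("Estimated finish:", timer.str_est_finish())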