[
  {
    "path": ".gitignore",
    "content": "__pycache__/\n*.pyc\n*.sw*\ndata/\nlibs/\nmodels/pretrained\nlogs\nsnapshots\n"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. 
For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. 
Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. (Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright 2021 TU Darmstadt\n\n   Author: Nikita Araslanov\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "# Dense Unsupervised Learning for Video Segmentation\n\n[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)\n[![Framework](https://img.shields.io/badge/PyTorch-%23EE4C2C.svg?&logo=PyTorch&logoColor=white)](https://pytorch.org/)\n\nThis repository contains the official implementation of our paper:\n\n**Dense Unsupervised Learning for Video Segmentation**<br>\n[Nikita Araslanov](https://arnike.github.io), [Simone Schaub-Mayer](https://schaubsi.github.io) and [Stefan Roth](https://www.visinf.tu-darmstadt.de/visinf/team_members/sroth/sroth.en.jsp)<br>\nTo appear at NeurIPS*2021. [[paper](https://openreview.net/pdf?id=i8kfkuiCJCI)] [[supp](https://openreview.net/attachment?id=i8kfkuiCJCI&name=supplementary_material)] [[talk](https://youtu.be/tSBWZ6nYld0)] [[example results](https://youtu.be/BqVtZJSLOzg)] [[arXiv](https://arxiv.org/abs/2111.06265)]\n\n
| <img src=\"assets/examples.gif\" alt=\"drawing\" width=\"420\"/><br> |\n|:--:|\n| <p align=\"left\">We efficiently learn spatio-temporal correspondences <br> without any supervision, and achieve state-of-the-art <br>accuracy in video object segmentation.</p> |\n\n\nContact: Nikita Araslanov *fname.lname* (at) visinf.tu-darmstadt.de\n\n\n---\n\n
## Installation\n**Requirements.** To reproduce our results, we recommend Python >=3.6, PyTorch >=1.4, CUDA >=10.0. At least one Titan X GPU (12GB) or equivalent is required.\nThe code was primarily developed under PyTorch 1.8 on a single A100 GPU.\n\nThe following steps will set up a local copy of the repository.\n\n1. Create a conda environment:\n```\nconda create --name dense-ulearn-vos\nsource activate dense-ulearn-vos\n```\n\n
2. Install PyTorch >=1.4 (see [PyTorch instructions](https://pytorch.org/get-started/locally/)). For example on Linux, run:\n\n```\nconda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch\n```\n\n3. Install the dependencies:\n```\npip install -r requirements.txt\n```\n\n4. Download the data:\n\n
| Dataset | Website | Target directory with video sequences |\n|:-:|:-:|:--|\n| YouTube-VOS | [Link](https://competitions.codalab.org/competitions/19544#participate-get-data) | `data/ytvos/train/JPEGImages/` |\n| OxUvA | [Link](https://oxuva.github.io/long-term-tracking-benchmark/) | `data/OxUvA/images/dev/` |\n| TrackingNet | [Link](https://github.com/SilvioGiancola/TrackingNet-devkit) | `data/tracking/train/jpegs/` |\n| Kinetics-400 | [Link](https://deepmind.com/research/open-source/kinetics) | `data/kinetics400/video_jpeg/train/` |\n\n
The last column specifies the path (relative to the project root) to the subdirectories containing the video frames.\nYou can, of course, use a different path structure; in that case, adjust the paths in `data/filelists/` for each dataset accordingly.\n
5. Download filelists:\n```\ncd data/filelists\nbash download.sh\n```\nThis will download the lists of training and validation paths for all datasets.\n\n
## Training\nThe following bash script will train a ResNet-18 model from scratch on one of the four supported datasets (see above):\n```\nbash ./launch/train.sh [ytvos|oxuva|track|kinetics]\n```\n\nWe also provide our final models for download.\n\n
| Dataset | Mean J&F (DAVIS-2017) | Link | MD5 |\n|---|:-:|:--:|---|\n| OxUvA | 65.3 | [oxuva_e430_res4.pth (132M)](https://download.visinf.tu-darmstadt.de/data/2021-neurips-araslanov-vos/snapshots/oxuva_e430_res4.pth) | `af541[...]d09b3` |\n| YouTube-VOS | 69.3 | [ytvos_e060_res4.pth (132M)](https://download.visinf.tu-darmstadt.de/data/2021-neurips-araslanov-vos/snapshots/ytvos_e060_res4.pth) | `c3ae3[...]55faf` |\n| TrackingNet | 69.4 | [trackingnet_e088_res4.pth (88M)](https://download.visinf.tu-darmstadt.de/data/2021-neurips-araslanov-vos/snapshots/trackingnet_e088_res4.pth) | `3e7e9[...]95fa9` |\n| Kinetics-400 | 68.7 | [kinetics_e026_res4.pth (88M)](https://download.visinf.tu-darmstadt.de/data/2021-neurips-araslanov-vos/snapshots/kinetics_e026_res4.pth) | `086db[...]a7d98` |\n\n\n
## Inference and evaluation\n\n### Inference\n\nTo run inference, use `launch/infer_vos.sh`:\n```\nbash ./launch/infer_vos.sh [davis|ytvos]\n```\nThe first argument selects the validation dataset (`davis` for DAVIS-2017; `ytvos` for YouTube-VOS).\nThe bash variables declared in the script set up the paths for reading the data and the pre-trained models, as well as the output directory:\n* `EXP`, `RUN_ID` and `SNAPSHOT` determine the pre-trained model to load.\n* `VER` specifies a suffix for the output directory (in case you would like to experiment with different configurations for label propagation).\nPlease refer to `launch/infer_vos.sh` for their usage.\n\n
The inference script creates two directories with the results: `[res3|res4|key]_vos` and `[res3|res4|key]_vis`, where the prefix corresponds to the codename of the output CNN layer used in the evaluation (selected in `infer_vos.sh` via the `KEY` variable).\nThe `vos`-directory contains the segmentation results ready for evaluation; the `vis`-directory contains the results for visualisation purposes.\nYou can optionally disable generating the visualisation by setting `VERBOSE=False` in `infer_vos.py`.\n\n\n
### Evaluation: DAVIS-2017\n\nPlease use the official [evaluation package](https://github.com/davisvideochallenge/davis2017-evaluation).\nInstall the repository, then run:\n```\npython evaluation_method.py --task semi-supervised --davis_path data/davis2017 --results_path <path-to-vos-directory>\n```\n\n
### Evaluation: YouTube-VOS 2018\nPlease use the official [CodaLab evaluation server](https://competitions.codalab.org/competitions/19544#participate-submit_results).\nTo create the submission, rename the `vos`-directory to `Annotations` and compress it to `Annotations.zip` for uploading.\n
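For example, the packaging might look like this (a sketch; `<path-to-vos-directory>` is the placeholder used above):\n```\ncp -r <path-to-vos-directory> Annotations\nzip -r Annotations.zip Annotations\n```\n\n
## Acknowledgements\n\nWe thank PyTorch contributors and [Allan Jabri](https://ajabri.github.io) for releasing [their implementation](https://github.com/ajabri/videowalk) of the label propagation.\n\n## Citation\nWe hope you find our work useful.\n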
If you would like to acknowledge it in your project, please use the following citation:\n```\n@inproceedings{Araslanov:2021:DUL,\n  author    = {Araslanov, Nikita and Schaub-Mayer, Simone and Roth, Stefan},\n  title     = {Dense Unsupervised Learning for Video Segmentation},\n  booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},\n  volume    = {34},\n  year      = {2021}\n}\n```\n"
  },
  {
    "path": "base_trainer.py",
    "content": "\"\"\"\nCopyright (c) 2021 TU Darmstadt\nAuthor: Nikita Araslanov <nikita.araslanov@tu-darmstadt.de>\nLicense: Apache License 2.0\n\"\"\"\n\nimport os\nimport torch\nimport math\n\nimport numpy as np\nimport torch.nn.functional as F\nimport torchvision.utils as vutils\n\nfrom torch.utils.tensorboard import SummaryWriter\nfrom torch.optim.optimizer import Optimizer\n\nfrom utils.checkpoints import Checkpoint\nfrom utils.palette_davis import palette as palette_davis\n\nfrom PIL import Image\n\nfrom matplotlib import cm\n\n
class BaseTrainer(object):\n\n    def __init__(self, args, cfg):\n        self.args = args\n        self.cfg = cfg\n        self.start_epoch = 0\n        self.best_score = -1e16\n        self.checkpoint = Checkpoint(args.snapshot_dir, max_n = 3)\n\n        # cache of batches used for visualisation (see save_vis_batch)\n        self.vis_batch = None\n\n        logdir = os.path.join(args.logdir, 'train')\n        self.writer = SummaryWriter(logdir)\n\n
    def checkpoint_best(self, score, epoch, temp):\n\n        if score > self.best_score:\n            print(\">>> Saving checkpoint with score {:3.2e}, epoch {}\".format(score, epoch))\n            self.best_score = score\n            self.checkpoint.checkpoint(score, epoch, temp)\n\n            return True\n\n        return False\n\n
    @staticmethod\n    def get_optim(params, cfg):\n\n        if not hasattr(torch.optim, cfg.OPT):\n            print(\"Optimiser {} not supported\".format(cfg.OPT))\n            raise NotImplementedError\n\n        optim = getattr(torch.optim, cfg.OPT)\n\n        if cfg.OPT == 'Adam':\n            print(\"Using Adam >>> learning rate = {:4.3e}, beta1 = {:4.3e}, weight decay = {:4.3e}\".format(cfg.LR, cfg.BETA1, cfg.WEIGHT_DECAY))\n            upd = torch.optim.Adam(params, lr=cfg.LR, \\\n                                   betas=(cfg.BETA1, 0.999), \\\n                                   weight_decay=cfg.WEIGHT_DECAY)\n        elif cfg.OPT == 'SGD':\n            print(\"Using SGD >>> learning rate = {:4.3e}, momentum = {:4.3e}, weight decay = {:4.3e}\".format(cfg.LR, cfg.MOMENTUM, cfg.WEIGHT_DECAY))\n            upd = torch.optim.SGD(params, lr=cfg.LR, \\\n                                  momentum=cfg.MOMENTUM, \\\n                                  nesterov=cfg.OPT_NESTEROV, \\\n                                  weight_decay=cfg.WEIGHT_DECAY)\n\n        else:\n            upd = optim(params, lr=cfg.LR)\n\n        upd.zero_grad()\n\n        return upd\n\n
    @staticmethod\n    def set_lr(optim, lr):\n        for param_group in optim.param_groups:\n            param_group['lr'] = lr\n\n
    def _downsize(self, x, mode=\"bilinear\"):\n        x = x.float()\n        if x.dim() == 3:\n            x = x.unsqueeze(1)\n\n        scale = min(*self.cfg.TB.IM_SIZE) / min(x.shape[-1], x.shape[-2])\n        if mode == \"nearest\":\n            x = F.interpolate(x, scale_factor=scale, mode=\"nearest\")\n        else:\n            x = F.interpolate(x, scale_factor=scale, mode=mode, align_corners=True)\n\n        return x.squeeze(1)\n\n
    def _visualise_seg(self, epoch, outs, writer, tag, S = 5):\n\n        def with_frame(image, mask, alpha=0.3):\n            return alpha * image + (1 - alpha) * mask\n\n        frames = outs[\"frames\"][::S]\n        frames_norm = self.denorm(frames.cpu().clone())\n        frames_down = self._downsize(frames_norm)\n        T,C,h,w = frames_down.shape\n\n        visuals = []\n        visuals.append(frames_down)\n\n        if \"masks_gt\" in outs:\n            mask_rgb_gt = self._apply_cmap(outs[\"masks_gt\"][::S].cpu(), palette_davis, rand=False)\n            mask_rgb_gt = self._downsize(mask_rgb_gt)\n            mask_rgb_gt = with_frame(frames_down, mask_rgb_gt)\n            visuals.append(mask_rgb_gt)\n\n
        mask_rgb_idx = self._apply_cmap(outs[\"masks_pred_idx\"][::S].cpu(), palette_davis, rand=False)\n        mask_rgb_idx = self._downsize(mask_rgb_idx)\n        mask_rgb_idx = with_frame(frames_down, mask_rgb_idx)\n        visuals.append(mask_rgb_idx)\n\n        conf = self._downsize(outs[\"masks_pred_conf\"][::S].cpu())\n        conf_rgb = self._error_rgb(conf, cm.get_cmap(\"plasma\"), frames_down, 0.3)\n        visuals.append(conf_rgb)\n\n        visuals = [x.float() for x in visuals]\n        visuals = torch.cat(visuals, -1)\n\n        self._visualise_grid(writer, visuals, epoch, tag)\n\n
    def _visualise(self, epoch, outs, T, writer, tag):\n        visuals = []\n\n        def overlay(mask, image, alpha=0.3):\n            return alpha * image + (1 - alpha) * mask\n\n        frames_orig = outs[\"frames_orig\"]\n        frames_orig = self.denorm(frames_orig.cpu().clone())\n        frames_orig = self._downsize(frames_orig)\n        visuals.append(frames_orig)\n\n        frames = outs[\"frames\"]\n        frames_norm = self.denorm(frames.cpu().clone())\n        frames_down = self._downsize(frames_norm)\n\n
        if \"grid_mask\" in outs:\n            val_mask = outs[\"grid_mask\"]\n            val_mask = self._downsize(val_mask)\n            val_mask = val_mask.unsqueeze(1).expand(-1,3,-1,-1).cpu()\n\n            val_mask = overlay(val_mask, frames_orig)\n            visuals.append(val_mask)\n\n        if \"map_target\" in outs:\n            val = outs[\"map_target\"]\n            val = self._apply_cmap(val)\n            val = self._downsize(val, \"nearest\")\n            val = overlay(val, frames_down)\n            visuals.append(val)\n\n        if \"map_soft\" in outs:\n            val = outs[\"map_soft\"]\n            val = self._mask_rgb(val)\n            val = self._downsize(val)\n            visuals.append(val)\n\n        visuals.append(frames_down)\n\n
        frames2 = outs[\"frames2\"]\n        frames2_norm = self.denorm(frames2.cpu().clone())\n        frames2_down = self._downsize(frames2_norm)\n        visuals.append(frames2_down)\n\n        if \"map\" in outs:\n            val = outs[\"map\"]\n            val = self._apply_cmap(val)\n            val = self._downsize(val, \"nearest\").cpu()\n            val = overlay(val, frames2_down)\n            visuals.append(val)\n\n        if \"map_target_soft\" in outs:\n            val = outs[\"map_target_soft\"]\n            val = self._mask_rgb(val)\n            val = self._downsize(val)\n            visuals.append(val)\n\n
        # embedding error mask\n        if \"error_map\" in outs:\n            err_mask = outs[\"error_map\"]\n            err_mask = (err_mask - err_mask.min()) / (err_mask.max() - err_mask.min() + 1e-8)\n            err_mask_rgb = self._error_rgb(err_mask, cmap=cm.get_cmap(\"plasma\"), alpha=0.5)\n            err_mask_rgb = self._downsize(err_mask_rgb)\n            visuals.append(err_mask_rgb)\n\n
        if \"aff_mask1\" in outs:\n            aff_mask = outs[\"aff_mask1\"].unsqueeze(1).expand(-1,3,-1,-1).cpu()\n            aff_mask = self._downsize(aff_mask)\n\n            aff_frames = frames_orig.clone()\n            aff_frames[::T] = overlay(aff_mask, aff_frames[::T], 0.5)\n            visuals.append(aff_frames)\n\n            aff_mask1 = self._error_rgb(outs[\"aff11\"], cm.get_cmap(\"inferno\"))\n            aff_mask1 = self._downsize(aff_mask1)\n            aff_mask1 = overlay(aff_mask1, frames_orig, 0.3)\n            visuals.append(aff_mask1)\n\n            aff_mask2 = self._error_rgb(outs[\"aff12\"], cm.get_cmap(\"inferno\"))\n            aff_mask2 = self._downsize(aff_mask2)\n            aff_mask2 = overlay(aff_mask2, frames2_down, 0.3)\n            visuals.append(aff_mask2)\n\n
        if \"aff_mask2\" in outs:\n            aff_mask = outs[\"aff_mask2\"].unsqueeze(1).expand(-1,3,-1,-1).cpu()\n            aff_mask = self._downsize(aff_mask)\n\n            aff_frames = frames_down.clone()\n            aff_frames[::T] = overlay(aff_mask, aff_frames[::T], 0.5)\n            visuals.append(aff_frames)\n\n            aff_mask1 = self._error_rgb(outs[\"aff21\"], cm.get_cmap(\"inferno\"))\n            aff_mask1 = self._downsize(aff_mask1)\n            aff_mask1 = overlay(aff_mask1, frames_orig, 0.3)\n            visuals.append(aff_mask1)\n\n            aff_mask2 = self._error_rgb(outs[\"aff22\"], cm.get_cmap(\"inferno\"))\n            aff_mask2 = self._downsize(aff_mask2)\n            aff_mask2 = overlay(aff_mask2, frames2_down, 0.3)\n            visuals.append(aff_mask2)\n\n
        visuals = [x.cpu().float() for x in visuals]\n        visuals = torch.cat(visuals, -1)\n\n        self._visualise_grid(writer, visuals, epoch, tag, 4 * T)\n\n
    def save_vis_batch(self, key, batch):\n\n        if self.vis_batch is None:\n            self.vis_batch = {}\n\n        if key in self.vis_batch:\n            return\n\n        batch_items = []\n        for el in batch:\n            el = el.clone().cpu() if torch.is_tensor(el) else el\n            batch_items.append(el)\n\n        self.vis_batch[key] = batch_items\n\n    def has_vis_batch(self, key):\n        return (not self.vis_batch is None and \\\n                    key in self.vis_batch)\n\n
    def _mask_rgb(self, masks, image_norm=None, palette=None, alpha=0.3):\n\n        if palette is None:\n            palette = self.loader.dataset.palette\n\n        # visualising masks\n        masks_conf, masks_idx = torch.max(masks, 1)\n        masks_conf = masks_conf - F.relu(masks_conf - 1)  # clamp confidences to [0, 1]\n\n        masks_idx_rgb = self._apply_cmap(masks_idx.cpu(), palette, mask_conf=masks_conf.cpu())\n        if not image_norm is None:\n            return alpha * image_norm + (1 - alpha) * masks_idx_rgb\n\n        return masks_idx_rgb\n\n
    def _apply_cmap(self, mask_idx, palette=None, mask_conf=None, rand=True):\n\n        if palette is None:\n            palette = self.loader.dataset.palette\n\n        ignore_mask = (mask_idx == -1).cpu()\n\n        # cycle\n        if rand:\n            memsize = self.cfg.TRAIN.BATCH_SIZE * self.cfg.MODEL.GRID_SIZE**2\n            mask_idx = ((mask_idx + 1) * 123) % memsize\n\n        # convert mask to RGB\n        mask = mask_idx.cpu().numpy().astype(np.uint32)\n        mask_rgb = palette(mask)\n        mask_rgb = torch.from_numpy(mask_rgb[:,:,:,:3])\n        mask_rgb[ignore_mask] *= 0\n        mask_rgb = mask_rgb.permute(0,3,1,2)\n\n        if not mask_conf is None:\n            # modulate colour intensity by confidence\n            mask_rgb *= mask_conf[:, None, :, :]\n\n        return mask_rgb\n\n
    def _error_rgb(self, error_mask, cmap = cm.get_cmap('jet'), image=None, alpha=0.3):\n        error_np = error_mask.cpu().numpy()\n\n        # remove alpha channel\n        error_rgb = cmap(error_np)[:, :, :, :3]\n        error_rgb = np.transpose(error_rgb, (0,3,1,2))\n        error_rgb = torch.from_numpy(error_rgb)\n\n        if not image is None:\n            return alpha * image + (1 - alpha) * error_rgb\n\n        return error_rgb\n\n
    def _visualise_grid(self, writer, x_all, t, tag, T=1):\n\n        # group every T images into one vertical grid\n        bs, ch, h, w = x_all.size()\n        x_all_new = torch.zeros(T, ch, h, w)\n        for b in range(bs):\n\n            x_all_new[b % T] = x_all[b]\n\n            if (b + 1) % T == 0:\n                summary_grid = vutils.make_grid(x_all_new, nrow=1, padding=8, pad_value=0.9).numpy()\n                writer.add_image(tag + \"_{:02d}\".format(b // T), summary_grid, t)\n                x_all_new.zero_()\n\n
    def visualise_results(self, epoch, writer, tag, step_func):\n        # visualising\n        self.net.eval()\n\n        with torch.no_grad():\n            step_func(epoch, self.vis_batch[tag], \\\n                      train=False, visualise=True, \\\n                      writer=writer, tag=tag)\n
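\n# Note (sketch): subclasses are expected to provide the attributes used above\n# that BaseTrainer does not set itself, e.g. a model in `self.net`, a dataloader\n# in `self.loader` and a denormalisation function in `self.denorm` (the dataset\n# classes in datasets/ provide one). A hypothetical subclass:\n#\n#   class Trainer(BaseTrainer):\n#       def __init__(self, args, cfg, net, loader):\n#           super().__init__(args, cfg)\n#           self.net = net\n#           self.loader = loader\n#           self.denorm = loader.dataset.denorm\n"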
  },
  {
    "path": "configs/kinetics.yaml",
    "content": "DATASET:\n  ROOT: \"data\"\n  SMALLEST_RANGE: [256,256]\n  RND_CROP: False\n  RND_ZOOM: True\n  RND_ZOOM_RANGE: [.5, 1.]\n  GUIDED_HFLIP: True\n  VIDEO_LEN: 5\n  FRAME_GAP: 10\n  MP4: True\nTRAIN:\n  BATCH_SIZE: 16\n  NUM_EPOCHS: 80\n  TASK: \"kinetics400\"\nMODEL:\n  LR: 0.001\n  CE_REF: 0.1\n  OPT: \"SGD\"\n  LR_SCHEDULER: \"step\"\n  LR_GAMMA: 0.5\n  LR_STEP: 20\n  WEIGHT_DECAY: 0.0005\nLOG:\n  ITER_TRAIN: 8\n  ITER_VAL: 2\nTEST:\n  KNN: 10\n  CXT_SIZE: 20\n  RADIUS: 12\n  TEMP: 0.05\nTB:\n  IM_SIZE: [196, 196]\n"
  },
  {
    "path": "configs/oxuva.yaml",
    "content": "DATASET:\n  ROOT: \"data\"\n  SMALLEST_RANGE: [256, 320]\n  RND_CROP: True\n  RND_ZOOM: True\n  RND_ZOOM_RANGE: [.5, 1.]\n  GUIDED_HFLIP: True\n  VIDEO_LEN: 5\n  FRAME_GAP: 10\nTRAIN:\n  BATCH_SIZE: 16\n  NUM_EPOCHS: 500\n  TASK: \"OxUvA\"\nMODEL:\n  LR: 0.0001\n  OPT: \"Adam\"\n  LR_SCHEDULER: \"step\"\n  LR_GAMMA: 0.5\n  LR_STEP: 100\n  WEIGHT_DECAY: 0.0005\nLOG:\n  ITER_TRAIN: 20\n  ITER_VAL: 10\nTEST:\n  KNN: 10\n  CXT_SIZE: 20\n  RADIUS: 12\n  TEMP: 0.05\nTB:\n  IM_SIZE: [196, 196]\n"
  },
  {
    "path": "configs/tracknet.yaml",
    "content": "DATASET:\n  ROOT: \"data\"\n  SMALLEST_RANGE: [256,256]\n  RND_CROP: False\n  RND_ZOOM: True\n  RND_ZOOM_RANGE: [.5, 1.]\n  GUIDED_HFLIP: True\n  VIDEO_LEN: 5\n  FRAME_GAP: 10\nTRAIN:\n  BATCH_SIZE: 16\n  NUM_EPOCHS: 250\n  TASK: \"TrackingNet\"\n  BLOCK_BN: False\n  STOP_GRAD: True\nMODEL:\n  LR: 0.001\n  CE_REF: 0.1\n  OPT: \"SGD\"\n  LR_SCHEDULER: \"step\"\n  LR_GAMMA: 0.9\n  LR_STEP: 20\n  WEIGHT_DECAY: 0.0005\nLOG:\n  ITER_TRAIN: 10\n  ITER_VAL: 2\nTEST:\n  KNN: 10\n  CXT_SIZE: 20\n  RADIUS: 12\n  TEMP: 0.05\nTB:\n  IM_SIZE: [196, 196]\n"
  },
  {
    "path": "configs/ytvos.yaml",
    "content": "DATASET:\n  ROOT: \"data\"\n  SMALLEST_RANGE: [256, 320]\n  RND_CROP: True\n  RND_ZOOM: True\n  RND_ZOOM_RANGE: [.5, 1.]\n  GUIDED_HFLIP: True\n  VIDEO_LEN: 5\n  FRAME_GAP: 2\nTRAIN:\n  BATCH_SIZE: 16\n  NUM_EPOCHS: 500\n  TASK: \"YTVOS\"\nMODEL:\n  LR: 0.0001\n  OPT: \"Adam\"\n  LR_SCHEDULER: \"step\"\n  LR_GAMMA: 0.5\n  LR_STEP: 100\n  WEIGHT_DECAY: 0.0005\nLOG:\n  ITER_TRAIN: 20\n  ITER_VAL: 10\nTEST:\n  KNN: 10\n  CXT_SIZE: 20\n  RADIUS: 12\n  TEMP: 0.05\nTB:\n  IM_SIZE: [196, 196]\n"
  },
  {
    "path": "core/__init__.py",
    "content": ""
  },
  {
    "path": "core/config.py",
    "content": "from __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\nfrom __future__ import unicode_literals\n\n\nimport yaml\nimport six\nimport os\nimport os.path as osp\nimport copy\nfrom ast import literal_eval\n\nimport numpy as np\nfrom packaging import version\n\nfrom utils.collections import AttrDict\n\n__C = AttrDict()\n# Consumers can get config by:\n#   from core.config import cfg\ncfg = __C\n\n__C.NUM_GPUS = 1\n# Random note: avoid using '.ON' as a config key since yaml converts it to True;\n# prefer 'ENABLED' instead\n\n
# ---------------------------------------------------------------------------- #\n# Training schedule options\n# ---------------------------------------------------------------------------- #\n__C.TRAIN = AttrDict()\n__C.TRAIN.BATCH_SIZE = 20\n__C.TRAIN.NUM_EPOCHS = 15\n__C.TRAIN.NUM_WORKERS = 4\n__C.TRAIN.TASK = \"YTVOS\"\n__C.TRAIN.BLOCK_BN = False\n__C.TRAIN.STOP_GRAD = True\n\n
# ---------------------------------------------------------------------------- #\n# Dataset options (+ augmentation options)\n# ---------------------------------------------------------------------------- #\n__C.DATASET = AttrDict()\n\n__C.DATASET.CROP_SIZE = [256,256]\n__C.DATASET.SMALLEST_RANGE = [256,320]\n__C.DATASET.RND_CROP  = False\n__C.DATASET.RND_HFLIP = True\n__C.DATASET.MP4 = False\n\n__C.DATASET.RND_ZOOM = False\n__C.DATASET.RND_ZOOM_RANGE = [.5,1.]\n__C.DATASET.GUIDED_HFLIP = True\n\n__C.DATASET.ROOT = \"data\"\n__C.DATASET.VIDEO_LEN = 5\n__C.DATASET.FRAME_GAP = 5\n__C.DATASET.VAL_FRAME_GAP = 2\n\n# number of classes used for the mask palette and one-hot mask encoding\n__C.DATASET.NUM_CLASSES = 7\n\n
# inference-time parameters\n__C.TEST = AttrDict()\n__C.TEST.RADIUS = 12\n__C.TEST.TEMP = 0.05\n__C.TEST.KNN = 10\n__C.TEST.CXT_SIZE = 20\n__C.TEST.INPUT_SIZE = -1\n__C.TEST.KEY = \"res4\"\n\n
# ---------------------------------------------------------------------------- #\n# Model options\n# ---------------------------------------------------------------------------- #\n__C.MODEL = AttrDict()\n__C.MODEL.ARCH = 'resnet18'\n__C.MODEL.LR_SCHEDULER = 'poly'\n__C.MODEL.LR_STEP = 5\n__C.MODEL.LR_GAMMA = 0.1 # divide by 10 every LR_STEP epochs\n__C.MODEL.LR_POWER = 1.0\n__C.MODEL.LR_SCHED_USE_EPOCH = True\n__C.MODEL.OPT = 'Adam'\n__C.MODEL.OPT_NESTEROV = False\n__C.MODEL.LR = 3e-4\n__C.MODEL.BETA1 = 0.5\n__C.MODEL.MOMENTUM = 0.9\n__C.MODEL.WEIGHT_DECAY = 1e-5\n__C.MODEL.REMOVE_LAYERS = []\n__C.MODEL.FEATURE_DIM = 128\n__C.MODEL.GRID_SIZE = 8\n__C.MODEL.GRID_SIZE_REF = 4\n__C.MODEL.CE_REF = 0.1\n\n
# ---------------------------------------------------------------------------- #\n# Logging options\n# ---------------------------------------------------------------------------- #\n__C.LOG = AttrDict()\n__C.LOG.ITER_VAL = 1\n__C.LOG.ITER_TRAIN = 10\n\n
# ---------------------------------------------------------------------------- #\n# TensorBoard options\n# ---------------------------------------------------------------------------- #\n__C.TB = AttrDict()\n__C.TB.IM_SIZE = (196, 196) # image resolution\n\n# [Inferred value]\n__C.CUDA = False\n__C.DEBUG = False\n\n# [Inferred value]\n__C.PYTORCH_VERSION_LESS_THAN_040 = False\n\n
def assert_and_infer_cfg(make_immutable=True):\n    \"\"\"Call this function in your script after you have finished setting all cfg\n    values that are necessary (e.g., merging a config from a file, merging\n    command line config options, etc.). 
By default, this function will also\n    mark the global cfg as immutable to prevent changing the global cfg settings\n    during script execution (which can lead to hard to debug errors or code\n    that's harder to understand than is necessary).\n    \"\"\"\n    if make_immutable:\n        cfg.immutable(True)\n\n\ndef merge_cfg_from_file(cfg_filename):\n    \"\"\"Load a yaml config file and merge it into the global config.\"\"\"\n    with open(cfg_filename, 'r') as f:\n        yaml_cfg = AttrDict(yaml.load(f, Loader=yaml.FullLoader))\n    _merge_a_into_b(yaml_cfg, __C)\n\ncfg_from_file = merge_cfg_from_file\n\n\ndef merge_cfg_from_cfg(cfg_other):\n    \"\"\"Merge `cfg_other` into the global config.\"\"\"\n    _merge_a_into_b(cfg_other, __C)\n\n\ndef merge_cfg_from_list(cfg_list):\n    \"\"\"Merge config keys, values in a list (e.g., from command line) into the\n    global config. For example, `cfg_list = ['TEST.NMS', 0.5]`.\n    \"\"\"\n    assert len(cfg_list) % 2 == 0\n    for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]):\n        key_list = full_key.split('.')\n        d = __C\n        for subkey in key_list[:-1]:\n            assert subkey in d, 'Non-existent key: {}'.format(full_key)\n            d = d[subkey]\n        subkey = key_list[-1]\n        assert subkey in d, 'Non-existent key: {}'.format(full_key)\n        value = _decode_cfg_value(v)\n        value = _check_and_coerce_cfg_value_type(\n            value, d[subkey], subkey, full_key\n        )\n        d[subkey] = value\n\ncfg_from_list = merge_cfg_from_list\n\n\ndef _merge_a_into_b(a, b, stack=None):\n    \"\"\"Merge config dictionary a into config dictionary b, clobbering the\n    options in b whenever they are also specified in a.\n    \"\"\"\n    assert isinstance(a, AttrDict), 'Argument `a` must be an AttrDict'\n    assert isinstance(b, AttrDict), 'Argument `b` must be an AttrDict'\n\n    for k, v_ in a.items():\n        full_key = '.'.join(stack) + '.' + k if stack is not None else k\n        # a must specify keys that are in b\n        if k not in b:\n            raise KeyError('Non-existent config key: {}'.format(full_key))\n\n        v = copy.deepcopy(v_)\n        v = _decode_cfg_value(v)\n        v = _check_and_coerce_cfg_value_type(v, b[k], k, full_key)\n\n        # Recursively merge dicts\n        if isinstance(v, AttrDict):\n            try:\n                stack_push = [k] if stack is None else stack + [k]\n                _merge_a_into_b(v, b[k], stack=stack_push)\n            except BaseException:\n                raise\n        else:\n            b[k] = v\n\n\ndef _decode_cfg_value(v):\n    \"\"\"Decodes a raw config value (e.g., from a yaml config files or command\n    line argument) into a Python object.\n    \"\"\"\n    # Configs parsed from raw yaml will contain dictionary keys that need to be\n    # converted to AttrDict objects\n    if isinstance(v, dict):\n        return AttrDict(v)\n    # All remaining processing is only applied to strings\n    if not isinstance(v, six.string_types):\n        return v\n    # Try to interpret `v` as a:\n    #   string, number, tuple, list, dict, boolean, or None\n    try:\n        v = literal_eval(v)\n    # The following two excepts allow v to pass through when it represents a\n    # string.\n    #\n    # Longer explanation:\n    # The type of v is always a string (before calling literal_eval), but\n    # sometimes it *represents* a string and other times a data structure, like\n    # a list. 
In the case that v represents a string, what we got back from the\n    # yaml parser is 'foo' *without quotes* (so, not '\"foo\"'). literal_eval is\n    # ok with '\"foo\"', but will raise a ValueError if given 'foo'. In other\n    # cases, like paths (v = 'foo/bar' and not v = '\"foo/bar\"'), literal_eval\n    # will raise a SyntaxError.\n    except ValueError:\n        pass\n    except SyntaxError:\n        pass\n    return v\n\n\n
def _check_and_coerce_cfg_value_type(value_a, value_b, key, full_key):\n    \"\"\"Checks that `value_a`, which is intended to replace `value_b`, is of the\n    right type. The type is correct if it matches exactly or is one of a few\n    cases in which the type can be easily coerced.\n    \"\"\"\n    # The types must match (with some exceptions)\n    type_b = type(value_b)\n    type_a = type(value_a)\n    if type_a is type_b:\n        return value_a\n\n
    # Exceptions: numpy arrays, strings, tuple<->list\n    if isinstance(value_b, np.ndarray):\n        value_a = np.array(value_a, dtype=value_b.dtype)\n    elif isinstance(value_b, six.string_types):\n        value_a = str(value_a)\n    elif isinstance(value_a, tuple) and isinstance(value_b, list):\n        value_a = list(value_a)\n    elif isinstance(value_a, list) and isinstance(value_b, tuple):\n        value_a = tuple(value_a)\n    else:\n        raise ValueError(\n            'Type mismatch ({} vs. {}) with values ({} vs. {}) for config '\n            'key: {}'.format(type_b, type_a, value_b, value_a, full_key)\n        )\n    return value_a\n
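\n# Example usage (a sketch; the paths and keys assume the configs in this repository):\n#\n#   from core.config import cfg, merge_cfg_from_file, merge_cfg_from_list, assert_and_infer_cfg\n#   merge_cfg_from_file('configs/ytvos.yaml')\n#   merge_cfg_from_list(['TEST.KNN', '5'])  # command-line style override\n#   assert_and_infer_cfg()                  # freeze the config\n#   print(cfg.TEST.KNN)\n"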
  },
  {
    "path": "datasets/__init__.py",
    "content": "\"\"\"\nCopyright (c) 2021 TU Darmstadt\nAuthor: Nikita Araslanov <nikita.araslanov@tu-darmstadt.de>\nLicense: Apache License 2.0\n\"\"\"\n\nimport torch\nfrom torch.utils import data\n\nfrom .dataloader_seg import DataSeg\nfrom .dataloader_video import DataVideo, DataVideoKinetics\n\ndef get_sets(task):\n\n    # fetch the names of data lists depending on the task\n    sets = {}\n    if task == \"OxUvA\":\n        sets[\"train_video\"] = \"train_oxuva\"\n        sets[\"val_video\"] = \"val_davis2017_480p\"\n        sets[\"val_video_seg\"] = \"val2_davis2017_480p\"\n    elif task == \"YTVOS\":\n        sets[\"train_video\"] = \"train_ytvos\"\n        sets[\"val_video\"] = \"val_davis2017_480p\"\n        sets[\"val_video_seg\"] = \"val2_davis2017_480p\"\n    elif task == \"TrackingNet\":\n        sets[\"train_video\"] = \"train_tracking\"\n        sets[\"val_video\"] = \"val_davis2017_480p\"\n        sets[\"val_video_seg\"] = \"val2_davis2017_480p\"\n    elif task == \"kinetics400\":\n        sets[\"train_video\"] = \"train_kinetics400\"\n        sets[\"val_video\"] = \"val_davis2017_480p\"\n        sets[\"val_video_seg\"] = \"val2_davis2017_480p\"\n    else:\n        raise NotImplementedError(\"Dataset '{}' not recognised.\".format(task))\n    \n    return sets\n\ndef get_dataloader(args, cfg, split):\n    assert split in (\"train\", \"val\")\n\n    task = cfg.TRAIN.TASK\n    data_sets = get_sets(cfg.TRAIN.TASK)\n\n    # total batch size: # of GPUs * batch size per GPU\n    batch_size = cfg.TRAIN.BATCH_SIZE\n    kwargs = {'pin_memory': True, 'num_workers': args.workers}\n    print(\"Dataloader: # workers {}\".format(args.workers))\n\n    def _dataloader(dataset, batch_size, shuffle=True, drop_last=False):\n        return data.DataLoader(dataset, batch_size=batch_size, \\\n                               shuffle=shuffle, drop_last=drop_last, **kwargs)\n\n\n    print(\"Split: \", split)\n    if split == \"train\":\n        VideoLoader = DataVideoKinetics if cfg.DATASET.MP4 else DataVideo\n        dataset_video = VideoLoader(cfg, data_sets[\"train_video\"])\n        return _dataloader(dataset_video, batch_size, drop_last=True)\n    else:\n        dataset_video = DataVideo(cfg, data_sets[\"val_video\"], val=True)\n        dataset_video_seg = DataSeg(cfg, data_sets[\"val_video_seg\"])\n\n        return {\"val_video\": _dataloader(dataset_video, batch_size, shuffle=False), \\\n                \"val_video_seg\": _dataloader(dataset_video_seg, 1, shuffle=False)}\n"
  },
  {
    "path": "datasets/dataloader_base.py",
    "content": "\"\"\"\nCopyright (c) 2021 TU Darmstadt\nAuthor: Nikita Araslanov <nikita.araslanov@tu-darmstadt.de>\nLicense: Apache License 2.0\n\"\"\"\n\nimport torch.utils.data as data\nfrom utils.palette import custom_palette\n\nclass DLBase(data.Dataset):\n\n    def __init__(self, *args, **kwargs):\n        super(DLBase, self).__init__(*args, **kwargs)\n\n        # RGB\n        self.MEAN = [0.485, 0.456, 0.406]\n        self.STD = [0.229, 0.224, 0.225]\n\n        self._init_means()\n\n    def _init_means(self):\n        self.MEAN255 = [255.*x for x in self.MEAN]\n        self.STD255 = [255.*x for x in self.STD]\n\n    def _init_palette(self, num_classes):\n        self.palette = custom_palette(num_classes)\n\n    def get_palette(self):\n        return self.palette\n\n    def remove_labels(self, mask):\n        # Remove labels not in training\n        for ignore_label in self.ignore_labels:\n            mask[mask == ignore_label] = 255\n\n        return mask.long()\n\n"
  },
  {
    "path": "datasets/dataloader_infer.py",
    "content": "\"\"\"\nCopyright (c) 2021 TU Darmstadt\nAuthor: Nikita Araslanov <nikita.araslanov@tu-darmstadt.de>\nLicense: Apache License 2.0\n\"\"\"\n\nimport os\nimport torch\n\nfrom PIL import Image\n\nimport numpy as np\nimport torchvision.transforms as tf\n\nfrom .dataloader_base import DLBase\n\n\n
class DataSeg(DLBase):\n\n    def __init__(self, cfg, split, ignore_labels=[], \\\n                 root=os.path.expanduser('./data'), renorm=False):\n\n        super(DataSeg, self).__init__()\n\n        self.cfg = cfg\n        self.root = root\n        self.split = split\n        self.ignore_labels = ignore_labels\n        self._init_palette(self.cfg.DATASET.NUM_CLASSES)\n\n        # train/val/test splits are pre-cut\n        split_fn = os.path.join(self.root, self.split + \".txt\")\n        assert os.path.isfile(split_fn)\n\n
        self.sequence_ids = []\n        self.sequence_names = []\n        def add_sequence(name):\n            vlen = len(self.images)\n            assert vlen >= cfg.DATASET.VIDEO_LEN, \\\n                \"Detected video shorter [{}] than training length [{}]\".format(vlen, \\\n                                                                cfg.DATASET.VIDEO_LEN)\n            self.sequence_ids.append(vlen)\n            self.sequence_names.append(name)\n            return vlen\n\n        self.images = []\n        self.masks = []\n        self.flags = []\n\n
        token = None\n        with open(split_fn, \"r\") as lines:\n            for line in lines:\n                _flag, _image, _mask = line.strip(\"\\n\").split(' ')\n\n                # save every frame\n                #_flag = 1\n                self.flags.append(int(_flag))\n\n                _image = os.path.join(cfg.DATASET.ROOT, _image.lstrip('/'))\n                assert os.path.isfile(_image), '%s not found' % _image\n\n                # each sequence may have a different length\n                # do some book-keeping e.g. to ensure we have\n                # sequences long enough for subsequent sampling\n                _token = _image.split(\"/\")[-2] # parent directory\n\n                # sequence ID is in the filename\n                #_token = os.path.basename(_image).split(\"_\")[0]\n                if token != _token:\n                    if not token is None:\n                        add_sequence(token)\n                    token = _token\n\n                self.images.append(_image)\n\n                if _mask is None:\n                    self.masks.append(None)\n                else:\n                    _mask = os.path.join(cfg.DATASET.ROOT, _mask.lstrip('/'))\n                    #assert os.path.isfile(_mask), '%s not found' % _mask\n                    self.masks.append(_mask)\n\n
        # update the last sequence\n        # returns the total amount of frames\n        add_sequence(token)\n        print(\"Loaded {} sequences\".format(len(self.sequence_ids)))\n\n        # defining data augmentation:\n        print(\"Dataloader: {}\".format(split), \" #\", len(self.images))\n        print(\"\\t {}: no augmentation\".format(split))\n\n        self.tf = tf.Compose([tf.ToTensor(), tf.Normalize(mean=self.MEAN, std=self.STD)])\n        self._num_samples = len(self.images)\n\n    def __len__(self):\n        return len(self.sequence_ids)\n\n
    def _mask2tensor(self, mask, num_classes=6):\n        h,w = mask.shape\n        ones = torch.ones(1,h,w)\n        zeros = torch.zeros(num_classes,h,w)\n\n        max_idx = mask.max()\n        assert max_idx < num_classes, \"{} >= {}\".format(max_idx, num_classes)\n        # one-hot encoding: scatter ones along the class dimension\n        return zeros.scatter(0, mask[None, ...], ones)\n\n
    def denorm(self, image):\n\n        if image.dim() == 3:\n            assert image.dim() == 3, \"Expected image [CxHxW]\"\n            assert image.size(0) == 3, \"Expected RGB image [3xHxW]\"\n\n            for t, m, s in zip(image, self.MEAN, self.STD):\n                t.mul_(s).add_(m)\n        elif image.dim() == 4:\n            # batch mode\n            assert image.size(1) == 3, \"Expected RGB image [3xHxW]\"\n\n            for t, m, s in zip((0,1,2), self.MEAN, self.STD):\n                image[:, t, :, :].mul_(s).add_(m)\n\n        return image\n\n\n
    def __getitem__(self, index):\n\n        seq_to = self.sequence_ids[index]\n        seq_from = 0 if index == 0 else self.sequence_ids[index - 1]\n\n        image0 = Image.open(self.images[seq_from])\n        w,h = image0.size\n\n        images, fns, flags = [], [], []\n        tracks = torch.LongTensor(self.cfg.DATASET.NUM_CLASSES).fill_(-1)\n        masks = torch.LongTensor(self.cfg.DATASET.NUM_CLASSES, h, w).zero_()\n        known_ids = set()\n\n
        for t in range(seq_from, seq_to):\n\n            t0 = t - seq_from\n            image = Image.open(self.images[t]).convert('RGB')\n\n            fns.append(os.path.basename(self.images[t].replace(\".jpg\", \"\")))\n            flags.append(self.flags[t])\n\n            if os.path.isfile(self.masks[t]):\n                mask = Image.open(self.masks[t])\n                mask = torch.from_numpy(np.array(mask, np.int64, copy=False))  # np.long is deprecated\n\n                unique_ids = np.unique(mask)\n                for oid in unique_ids:\n                    if not oid in known_ids:\n                        tracks[oid] = t0\n                        known_ids.add(oid)\n                        masks[oid] = (mask == oid).long()\n            else:\n                mask = Image.new('L', image.size)\n\n            image = self.tf(image)\n            images.append(image)\n\n
        images = torch.stack(images, 0)\n        seq_name = self.sequence_names[index]\n        flags = torch.LongTensor(flags)\n\n        return images, images, masks, tracks, len(known_ids), fns, flags, seq_name\n
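\n# Usage sketch (hypothetical): each item is one full sequence,\n#\n#   images, _, masks, tracks, n_ids, fns, flags, seq_name = dataset[0]\n#\n# where tracks[oid] holds the first frame index at which object oid appears\n# and masks[oid] is its binary mask in that frame.\n"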
  },
  {
    "path": "datasets/dataloader_seg.py",
    "content": "\"\"\"\nCopyright (c) 2021 TU Darmstadt\nAuthor: Nikita Araslanov <nikita.araslanov@tu-darmstadt.de>\nLicense: Apache License 2.0\n\"\"\"\n\nimport os\nimport torch\n\nfrom PIL import Image\n\nimport numpy as np\nimport torch.utils.data as data\nimport torchvision.transforms as tf\n\nfrom .dataloader_base import DLBase\n\n
class DataSeg(DLBase):\n\n    def __init__(self, cfg, split, ignore_labels=[], \\\n                 root=os.path.expanduser('./data'), renorm=False):\n\n        super(DataSeg, self).__init__()\n\n        self.cfg = cfg\n        self.root = root\n        self.split = split\n        self.ignore_labels = ignore_labels\n\n        self._init_palette(cfg.DATASET.NUM_CLASSES)\n\n        # train/val/test splits are pre-cut\n        split_fn = os.path.join(self.root, \"filelists\", self.split + \".txt\")\n        assert os.path.isfile(split_fn)\n\n
        self.sequence_ids = []\n        self.sequence_names = []\n        def add_sequence(name):\n            vlen = len(self.images)\n            assert vlen >= cfg.DATASET.VIDEO_LEN, \\\n                \"Detected video shorter [{}] than training length [{}]\".format(vlen, \\\n                                                                cfg.DATASET.VIDEO_LEN)\n            self.sequence_ids.append(vlen)\n            self.sequence_names.append(name)\n            return vlen\n\n        self.images = []\n        self.masks = []\n\n
        token = None\n        with open(split_fn, \"r\") as lines:\n            for line in lines:\n                _image = line.strip(\"\\n\").split(' ')\n\n                _mask = None\n                if len(_image) == 2:\n                    _image, _mask = _image\n                else:\n                    assert len(_image) == 1\n                    _image = _image[0]\n\n                _image = os.path.join(cfg.DATASET.ROOT, _image.lstrip('/'))\n                assert os.path.isfile(_image), '%s not found' % _image\n                self.images.append(_image)\n\n                # each sequence may have a different length\n                # do some book-keeping e.g. to ensure we have\n                # sequences long enough for subsequent sampling\n                _token = _image.split(\"/\")[-2] # parent directory\n\n                # sequence ID is in the filename\n                #_token = os.path.basename(_image).split(\"_\")[0]\n                if token != _token:\n                    if not token is None:\n                        add_sequence(token)\n                    token = _token\n\n                if _mask is None:\n                    self.masks.append(None)\n                else:\n                    _mask = os.path.join(cfg.DATASET.ROOT, _mask.lstrip('/'))\n                    assert os.path.isfile(_mask), '%s not found' % _mask\n                    self.masks.append(_mask)\n\n
        # update the last sequence\n        # returns the total amount of frames\n        add_sequence(token)\n        print(\"Loaded {} sequences\".format(len(self.sequence_ids)))\n\n        # defining data augmentation:\n        print(\"Dataloader: {}\".format(split), \" #\", len(self.images))\n        print(\"\\t {}: no augmentation\".format(split))\n\n        self.tf = tf.Compose([tf.ToTensor(), tf.Normalize(mean=self.MEAN, std=self.STD)])\n        self._num_samples = len(self.images)\n\n    def __len__(self):\n        return len(self.sequence_ids)\n\n
    def denorm(self, image):\n\n        if image.dim() == 3:\n            assert image.dim() == 3, \"Expected image [CxHxW]\"\n            assert image.size(0) == 3, \"Expected RGB image [3xHxW]\"\n\n            for t, m, s in zip(image, self.MEAN, self.STD):\n                t.mul_(s).add_(m)\n        elif image.dim() == 4:\n            # batch mode\n            assert image.size(1) == 3, \"Expected RGB image [3xHxW]\"\n\n            for t, m, s in zip((0,1,2), self.MEAN, self.STD):\n                image[:, t, :, :].mul_(s).add_(m)\n\n        return image\n\n
    def _mask2tensor(self, mask, num_classes=6):\n        h,w = mask.shape\n        ones = torch.ones(1,h,w)\n        zeros = torch.zeros(num_classes,h,w)\n\n        max_idx = mask.max()\n        assert max_idx < num_classes, \"{} >= {}\".format(max_idx, num_classes)\n        # one-hot encoding: scatter ones along the class dimension\n        return zeros.scatter(0, mask[None, ...], ones)\n\n
    def __getitem__(self, index):\n\n        seq_to = self.sequence_ids[index] - 1\n        seq_from = 0 if index == 0 else self.sequence_ids[index-1] - 1\n\n        images, masks = [], []\n        n_obj = 0\n        for _id_ in range(seq_from, seq_to):\n\n            image = Image.open(self.images[_id_]).convert('RGB')\n\n            if self.masks[_id_] is None:\n                mask = Image.new('L', image.size)\n            else:\n                mask = Image.open(self.masks[_id_]) #.convert('L')\n\n            image = self.tf(image)\n            images.append(image)\n\n            mask = torch.from_numpy(np.array(mask, np.int64, copy=False))  # np.long is deprecated\n            n_obj = max(n_obj, mask.max().item())\n            masks.append(self._mask2tensor(mask))\n\n
        images = torch.stack(images, 0)\n        masks = torch.stack(masks, 0)\n        n_obj = torch.LongTensor([n_obj + 1]) # +1 background\n        seq_name = self.sequence_names[index]\n\n        return images, masks, n_obj, seq_name\n
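\n# Usage sketch (hypothetical): each item is one full sequence,\n#\n#   images, masks, n_obj, seq_name = dataset[0]\n#\n# with images of shape [T,3,H,W], one-hot masks of shape [T,C,H,W]\n# and n_obj the number of objects incl. the background.\n"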
  },
  {
    "path": "datasets/dataloader_video.py",
    "content": "\"\"\"\nCopyright (c) 2021 TU Darmstadt\nAuthor: Nikita Araslanov <nikita.araslanov@tu-darmstadt.de>\nLicense: Apache License 2.0\n\"\"\"\n\nimport os\nimport random\nimport torch\nimport math\nimport glob\n\nfrom PIL import Image\nimport numpy as np\n\nfrom .dataloader_base import DLBase\nimport datasets.daugm_video as tf\n\nclass DataVideo(DLBase):\n\n    def __init__(self, cfg, split, val=False):\n        super(DataVideo, self).__init__()\n\n        self.cfg = cfg\n        self.split = split\n        self.val = val\n\n        self.cfg_frame_gap = cfg.DATASET.VAL_FRAME_GAP if val else cfg.DATASET.FRAME_GAP\n\n        self._init_palette(cfg.TRAIN.BATCH_SIZE * cfg.MODEL.GRID_SIZE**2)\n\n        # train/val/test splits are pre-cut\n        split_fn = os.path.join(cfg.DATASET.ROOT, \"filelists\", self.split + \".txt\")\n        assert os.path.isfile(split_fn)\n\n        self.images = []\n\n        token = None # sequence token (new video when it changes)\n\n        subsequence = []\n        ignored = [0]\n        num_frames = [0]\n\n        def add_sequence():\n\n            if cfg.DATASET.VIDEO_LEN > len(subsequence):\n                # found a very short sequence\n                ignored[0] += 1\n            else:\n                # adding the subsequence\n                self.images.append(tuple(subsequence))\n                num_frames[0] += len(subsequence)\n\n            del subsequence[:]\n\n        with open(split_fn, \"r\") as lines:\n            for line in lines:\n                _line = line.strip(\"\\n\").split(' ')\n\n                assert len(_line) > 0, \"Expected at least one path\"\n                _image = _line[0]\n\n                # each sequence may have a different length\n                # do some book-keeping e.g. 
to ensure we have\n                # sequences long enough for subsequent sampling\n                _token = _image.split(\"/\")[-2] # parent directory\n\n                # alternative: sequence ID is in the filename\n                #_token = os.path.basename(_image).split(\"_\")[0]\n                if token != _token:\n                    if token is not None:\n                        add_sequence()\n                    token = _token\n\n                _image = os.path.join(cfg.DATASET.ROOT, _image.lstrip('/'))\n                #assert os.path.isfile(_image), '%s not found' % _image\n                subsequence.append(_image)\n\n        # register the last, still-pending sequence\n        add_sequence()\n        print(\"Dataloader: {}\".format(split), \" / Frame Gap: \", self.cfg_frame_gap)\n        print(\"Loaded {} sequences / {} ignored / {} frames\".format(len(self.images), \\\n                                                                    ignored[0], \\\n                                                                    num_frames[0]))\n\n        self._num_samples = num_frames[0]\n        self._init_augm(cfg)\n\n    def _init_augm(self, cfg):\n\n        # general (unguided) affine transformations\n        tfs_pre = [tf.CreateMask()]\n        self.tf_pre = tf.Compose(tfs_pre)\n\n        # guided affine transformations\n        tfs_affine = []\n\n        # 1.\n        # general affine transformations\n        #\n        tfs_pre.append(tf.MaskScaleSmallest(cfg.DATASET.SMALLEST_RANGE))\n\n        if cfg.DATASET.RND_CROP:\n            tfs_pre.append(tf.MaskRandCrop(cfg.DATASET.CROP_SIZE, pad_if_needed=True))\n        else:\n            tfs_pre.append(tf.MaskCenterCrop(cfg.DATASET.CROP_SIZE))\n\n        if cfg.DATASET.RND_HFLIP:\n            tfs_pre.append(tf.MaskRandHFlip())\n\n        # 2.\n        # guided affine transformations\n        #\n        if cfg.DATASET.GUIDED_HFLIP:\n            tfs_affine.append(tf.GuidedRandHFlip())\n\n        # this will add an affine transformation\n        if cfg.DATASET.RND_ZOOM:\n            tfs_affine.append(tf.MaskRandScaleCrop(*cfg.DATASET.RND_ZOOM_RANGE))\n\n        self.tf_affine = tf.Compose(tfs_affine)\n        self.tf_affine2 = tf.Compose([tf.AffineIdentity()])\n\n        tfs_post = [tf.ToTensorMask(),\n                    tf.Normalize(mean=self.MEAN, std=self.STD),\n                    tf.ApplyMask(-1)]\n\n        # image to the teacher will have no noise\n        self.tf_post = tf.Compose(tfs_post)\n\n    def set_num_samples(self, n):\n        print(\"Re-setting # of samples: {:d} -> {:d}\".format(self._num_samples, n))\n        self._num_samples = n\n\n    def __len__(self):\n        return len(self.images) #self._num_samples\n\n    def denorm(self, image):\n\n        if image.dim() == 3:\n            assert image.size(0) == 3, \"Expected RGB image [3xHxW]\"\n\n            for t, m, s in zip(image, self.MEAN, self.STD):\n                t.mul_(s).add_(m)\n        elif image.dim() == 4:\n            # batch mode\n            assert image.size(1) == 3, \"Expected RGB batch [Bx3xHxW]\"\n\n            for t, m, s in zip((0,1,2), self.MEAN, self.STD):\n                image[:, t, :, :].mul_(s).add_(m)\n\n        return image\n\n    def _get_affine(self, params):\n\n        N = len(params)\n\n        # construct affine operator\n        affine = torch.zeros(N, 2, 3)\n\n     
   aspect_ratio = float(self.cfg.DATASET.CROP_SIZE[0]) / \\\n                            float(self.cfg.DATASET.CROP_SIZE[1])\n\n        for i, (dy,dx,alpha,scale,flip) in enumerate(params):\n\n            # R inverse\n            sin = math.sin(alpha * math.pi / 180.)\n            cos = math.cos(alpha * math.pi / 180.)\n\n            # inverse, note how flipping is incorporated\n            affine[i,0,0], affine[i,0,1] = flip * cos, sin * aspect_ratio\n            affine[i,1,0], affine[i,1,1] = -sin / aspect_ratio, cos\n\n            # T inverse Rinv * t == R^T * t\n            affine[i,0,2] = -1. * (cos * dx + sin * dy)\n            affine[i,1,2] = -1. * (-sin * dx + cos * dy)\n\n            # T\n            affine[i,0,2] /= float(self.cfg.DATASET.CROP_SIZE[1] // 2)\n            affine[i,1,2] /= float(self.cfg.DATASET.CROP_SIZE[0] // 2)\n\n            # scaling\n            affine[i] *= scale\n\n        return affine\n\n    def _get_affine_inv(self, affine, params):\n\n        aspect_ratio = float(self.cfg.DATASET.CROP_SIZE[0]) / \\\n                            float(self.cfg.DATASET.CROP_SIZE[1])\n\n        affine_inv = affine.clone()\n        affine_inv[:,0,1] = affine[:,1,0] * aspect_ratio**2\n        affine_inv[:,1,0] = affine[:,0,1] / aspect_ratio**2\n        affine_inv[:,0,2] = -1 * (affine_inv[:,0,0] * affine[:,0,2] + affine_inv[:,0,1] * affine[:,1,2])\n        affine_inv[:,1,2] = -1 * (affine_inv[:,1,0] * affine[:,0,2] + affine_inv[:,1,1] * affine[:,1,2])\n\n        # scaling\n        affine_inv /= torch.Tensor(params)[:,3].view(-1,1,1)**2\n\n        return affine_inv\n\n    def __getitem__(self, index):\n\n        # searching for the video clip ID\n        sequence = self.images[index] # % len(self.images)]\n        seqlen = len(sequence)\n\n        assert self.cfg_frame_gap > 0, \"Frame gap should be positive\"\n        t_window = self.cfg_frame_gap * self.cfg.DATASET.VIDEO_LEN\n\n        # reduce sampling gap for short clips\n        t_window = min(seqlen, t_window)\n        frame_gap = t_window // self.cfg.DATASET.VIDEO_LEN\n\n        # strided slice\n        frame_ids = torch.arange(t_window)[::frame_gap]\n        frame_ids = frame_ids[:self.cfg.DATASET.VIDEO_LEN]\n        assert len(frame_ids) == self.cfg.DATASET.VIDEO_LEN\n\n        # selecting a random start\n        index_start = random.randint(0, seqlen - frame_ids[-1] - 1)\n        # permuting the frames in the batch\n        random_ids = torch.randperm(self.cfg.DATASET.VIDEO_LEN)\n        # adding the offset\n        frame_ids = frame_ids[random_ids] + index_start\n\n        # forward sequence\n        images = []\n        for frame_id in frame_ids:\n            fn = sequence[frame_id]\n            images.append(Image.open(fn).convert('RGB'))\n\n        # 1. general transforms\n        frames, valid = self.tf_pre(images)\n\n        # 1.1 creating two sequences in forward/backward order\n        frames1, valid1 = frames[:], valid[:]\n\n        # second copy\n        frames2 = [f.copy() for f in frames]\n        valid2 = [v.copy() for v in valid]\n\n        # 2. 
guided affine transforms\n        frames1, valid1, affine_params1 = self.tf_affine(frames1, valid1)\n        frames2, valid2, affine_params2 = self.tf_affine2(frames2, valid2)\n\n        # convert to tensor, zero out the values\n        frames1 = self.tf_post(frames1, valid1)\n        frames2 = self.tf_post(frames2, valid2)\n\n        # converting the affine transforms\n        aff_reg = self._get_affine(affine_params1)\n        aff_main = self._get_affine(affine_params2)\n\n        aff_reg_inv = self._get_affine_inv(aff_reg, affine_params1)\n\n        aff_reg = aff_main # identity affine2_inv\n        aff_main = aff_reg_inv\n\n        frames1 = torch.stack(frames1, 0)\n        frames2 = torch.stack(frames2, 0)\n\n        assert frames1.shape == frames2.shape\n\n        return frames2, frames1, aff_main, aff_reg\n\nclass DataVideoKinetics(DataVideo):\n\n    def __init__(self, cfg, split):\n        super(DataVideo, self).__init__()\n\n        self.cfg = cfg\n        self.split = split\n\n        self._init_palette(cfg.TRAIN.BATCH_SIZE * cfg.MODEL.GRID_SIZE**2)\n\n        # train/val/test splits are pre-cut\n        split_fn = os.path.join(cfg.DATASET.ROOT, \"filelists\", self.split + \".txt\")\n        assert os.path.isfile(split_fn)\n\n        self.videos = []\n\n        with open(split_fn, \"r\") as lines:\n            for line in lines:\n                _line = line.strip(\"\\n\").split(' ')\n                assert len(_line) > 0, \"Expected at least one path\"\n\n                _vid = _line[0]\n    \n                # image 1\n                _vid = os.path.join(cfg.DATASET.ROOT, _vid.lstrip('/'))\n                #assert os.path.isdir(_vid), \"{} does not exist\".format(_vid)\n                self.videos.append(_vid)\n\n        # update the last sequence\n        # returns the total amount of frames\n        print(\"DataloaderKinetics: {}\".format(split))\n        print(\"Loaded {} sequences\".format(len(self.videos)))\n        self._init_augm(cfg)\n\n    def __len__(self):\n        return len(self.videos)\n\n    def __getitem__(self, index):\n\n        C = self.cfg.DATASET\n        path = self.videos[index]\n\n        # filenames\n        fns = sorted(glob.glob(path + \"/*.jpg\"))\n        total_len = len(fns)\n\n        # temporal window to consider\n        temp_window = min(total_len, C.VIDEO_LEN * C.FRAME_GAP)\n        gap = temp_window / C.VIDEO_LEN\n        start_frame = random.randint(0, total_len - temp_window)\n\n        images = []\n        for idx in range(C.VIDEO_LEN):\n            frame_id = start_frame + int(idx * gap)\n            fn = fns[frame_id]\n            images.append(Image.open(fn).convert('RGB'))\n\n        # 1. general transforms\n        frames, valid = self.tf_pre(images)\n\n        # 1.1 creating two sequences in forward/backward order\n        frames1, valid1 = frames[:], valid[:]\n\n        # second copy\n        frames2 = [f.copy() for f in frames]\n        valid2 = [v.copy() for v in valid]\n\n        # 2. 
guided affine transforms\n        frames1, valid1, affine_params1 = self.tf_affine(frames1, valid1)\n        frames2, valid2, affine_params2 = self.tf_affine2(frames2, valid2)\n\n        # convert to tensor, zero out the values\n        frames1 = self.tf_post(frames1, valid1)\n        frames2 = self.tf_post(frames2, valid2)\n\n        # converting the affine transforms\n        affine1 = self._get_affine(affine_params1)\n        affine2 = self._get_affine(affine_params2)\n\n        affine1_inv = self._get_affine_inv(affine1, affine_params1)\n\n        affine1 = affine2 # identity affine2_inv\n        affine2 = affine1_inv\n\n        frames1 = torch.stack(frames1, 0)\n        frames2 = torch.stack(frames2, 0)\n\n        assert frames1.shape == frames2.shape\n\n        return frames2, frames1, affine2, affine1\n"
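DataVideo.__getitem__ above first shrinks the sampling window for short clips, then takes a strided slice, adds a random temporal offset and permutes the clip order. A self-contained sketch of that index logic; video_len and frame_gap are assumed values, not taken from a config:

import random
import torch

def sample_frame_ids(seqlen, video_len=5, frame_gap=2):
    # reduce the temporal window for clips shorter than video_len * frame_gap
    t_window = min(seqlen, frame_gap * video_len)
    gap = t_window // video_len

    # strided slice of exactly video_len indices
    frame_ids = torch.arange(t_window)[::gap][:video_len]
    assert len(frame_ids) == video_len

    # random temporal offset, then shuffle the order within the clip
    start = random.randint(0, seqlen - int(frame_ids[-1]) - 1)
    return frame_ids[torch.randperm(video_len)] + start

print(sample_frame_ids(seqlen=30))   # e.g. tensor([9, 5, 13, 7, 11])

As in the code above, seqlen must be at least video_len for the stride to be non-zero.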
  },
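A quick numeric sanity check, not part of the repo, that _get_affine_inv composes with _get_affine to the identity. Both formulas are restated standalone; the crop size, aspect ratio and (dy, dx, alpha, scale, flip) tuples are arbitrary test values. The analytic inverse is exact as long as a flip never co-occurs with a rotation, which holds in this pipeline since no transform ever sets alpha.

import math
import torch

def get_affine(params, ar, crop_hw):
    # same formulas as DataVideo._get_affine
    A = torch.zeros(len(params), 2, 3)
    for i, (dy, dx, alpha, scale, flip) in enumerate(params):
        sin, cos = math.sin(math.radians(alpha)), math.cos(math.radians(alpha))
        A[i, 0, 0], A[i, 0, 1] = flip * cos, sin * ar
        A[i, 1, 0], A[i, 1, 1] = -sin / ar, cos
        A[i, 0, 2] = -(cos * dx + sin * dy) / (crop_hw[1] // 2)
        A[i, 1, 2] = -(-sin * dx + cos * dy) / (crop_hw[0] // 2)
        A[i] *= scale
    return A

def get_affine_inv(A, params, ar):
    # same formulas as DataVideo._get_affine_inv
    Ai = A.clone()
    Ai[:, 0, 1] = A[:, 1, 0] * ar ** 2
    Ai[:, 1, 0] = A[:, 0, 1] / ar ** 2
    Ai[:, 0, 2] = -(Ai[:, 0, 0] * A[:, 0, 2] + Ai[:, 0, 1] * A[:, 1, 2])
    Ai[:, 1, 2] = -(Ai[:, 1, 0] * A[:, 0, 2] + Ai[:, 1, 1] * A[:, 1, 2])
    return Ai / torch.Tensor(params)[:, 3].view(-1, 1, 1) ** 2

params = [(3.0, -2.0, 0.0, 1.25, 1.0), (5.0, 4.0, 0.0, 0.8, -1.0)]
A = get_affine(params, ar=1.0, crop_hw=(256, 256))
Ai = get_affine_inv(A, params, ar=1.0)

# promote [N,2,3] to [N,3,3] and compose; the product should be the identity
row = torch.tensor([[[0.0, 0.0, 1.0]]]).expand(len(params), 1, 3)
compose = torch.cat([A, row], 1) @ torch.cat([Ai, row], 1)
print(torch.allclose(compose, torch.eye(3).expand(len(params), 3, 3), atol=1e-5))  # True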
  {
    "path": "datasets/daugm_video.py",
    "content": "\"\"\"\nCopyright (c) 2021 TU Darmstadt\nAuthor: Nikita Araslanov <nikita.araslanov@tu-darmstadt.de>\nLicense: Apache License 2.0\n\"\"\"\n\nimport random\nimport numpy as np\nimport torch\nfrom PIL import Image\n\nimport torchvision.transforms as tf\nimport torchvision.transforms.functional as F\n\nclass Compose:\n    # Composes segtransforms: segtransform.Compose([segtransform.RandScale([0.5, 2.0]), segtransform.ToTensor()])\n    def __init__(self, segtransform):\n        self.segtransform = segtransform\n\n    def __call__(self, args, *more_args):\n\n        # allow for intermediate representations\n        for t in self.segtransform:\n            result = t(args, *more_args)\n            args = result[0]\n            more_args = result[1:]\n\n        return result\n\nclass ToTensorMask:\n\n    def __toByteTensor(self, pic):\n        return torch.from_numpy(np.array(pic, np.int32, copy=False))\n\n    def __call__(self, images, masks):\n\n        new_masks = []\n        for i, (image, mask) in enumerate(zip(images, masks)):\n            images[i] = F.to_tensor(image)\n            new_masks.append(self.__toByteTensor(mask))\n\n        return images, new_masks\n\nclass CreateMask:\n    \"\"\"Create mask to hold invalid pixels\n    (e.g. from rotations or downscaling)\n    \"\"\"\n\n    def __call__(self, images):\n        \n        masks = []\n        for i, image in enumerate(images):\n            masks.append(Image.new(\"L\", image.size))\n\n        return images, masks\n\nclass Normalize:\n\n    # Normalize tensor with mean and standard deviation along channel: channel = (channel - mean) / std\n    def __init__(self, mean, std=None):\n\n        if std is None:\n            assert len(mean) > 0\n        else:\n            assert len(mean) == len(std)\n\n        self.mean = mean\n        self.std = std\n\n    def __call__(self, images, masks):\n\n        for i, (image, mask) in enumerate(zip(images, masks)):\n\n            if self.std is None:\n                for t, m in zip(image, self.mean):\n                    t.sub_(m)\n            else:\n                for t, m, s in zip(image, self.mean, self.std):\n                    t.sub_(m).div_(s)\n\n        return images, masks\n\nclass ApplyMask:\n\n    def __init__(self, ignore_label):\n        self.ignore_label = ignore_label\n\n    def __call__(self, images, masks):\n\n        for i, (image, mask) in enumerate(zip(images, masks)):\n            mask = mask > 0.\n            images[i] *= (1. - mask.type_as(image))\n\n        return images\n\nclass GuidedRandHFlip:\n\n    def __call__(self, images, mask, affine=None):\n\n        if affine is None:\n            affine = [[0.,0.,0.,1.,1.] for _ in images]\n\n        if random.random() > 0.5:\n            for i, image in enumerate(images):\n                affine[i][4] *= -1\n                images[i] = F.hflip(image)\n\n        return images, mask, affine\n\nclass AffineIdentity(object):\n\n    def __call__(self, images, masks, affine=None):\n\n        if affine is None:\n            affine = [[0.,0.,0.,1.,1.] 
for _ in images]\n\n        return images, masks, affine\n\nclass MaskRandScaleCrop(object):\n\n    def __init__(self, scale_from, scale_to):\n        self.scale_from = scale_from\n        self.scale_to = scale_to\n        #assert scale_from >= 1., \"Zooming in is not supported yet\"\n\n    def get_scale(self):\n        return random.uniform(self.scale_from, self.scale_to)\n\n    def get_params(self, h, w, new_scale):\n        # generate a random crop (aspect ratio is preserved)\n        new_h = int(new_scale * h)\n        new_w = int(new_scale * w)\n\n        if new_scale <= 1.:\n            assert w >= new_w and h >= new_h, \"{} vs. {} | {} / {}\".format(w, new_w, h, new_h)\n            i = random.randint(0, h - new_h)\n            j = random.randint(0, w - new_w)\n        else:\n            assert w <= new_w and h <= new_h, \"{} vs. {} | {} / {}\".format(w, new_w, h, new_h)\n            i = random.randint(h - new_h, 0)\n            j = random.randint(w - new_w, 0)\n\n        return i, j, new_h, new_w\n\n    def __call__(self, images, masks, affine=None):\n\n        if affine is None:\n            affine = [[0.,0.,0.,1.,1.] for _ in images]\n\n        W, H = images[0].size\n\n        i2 = H / 2\n        j2 = W / 2\n\n        masks_new = []\n\n        # one crop for all\n        s = self.get_scale()\n\n        ii, jj, h, w = self.get_params(H, W, s)\n\n        # displacement of the centre\n        dy = ii + h / 2 - i2\n        dx = jj + w / 2 - j2\n\n        for k, image in enumerate(images):\n\n            affine[k][0] = dy\n            affine[k][1] = dx\n            affine[k][3] = 1 / s # scale\n\n            if s <= 1.:\n                assert ii >= 0 and jj >= 0\n                # zooming in\n                image_crop = F.crop(image, ii, jj, h, w)\n                images[k] = image_crop.resize((W, H), Image.BILINEAR)\n\n                mask_crop = F.crop(masks[k], ii, jj, h, w)\n                masks_new.append(mask_crop.resize((W, H), Image.NEAREST))\n            else:\n                assert ii <= 0 and jj <= 0\n                # zooming out\n                pad_l = abs(jj)\n                pad_r = w - W - pad_l\n                pad_t = abs(ii)\n                pad_b = h - H - pad_t\n\n                image_pad = F.pad(image, (pad_l, pad_t, pad_r, pad_b))\n                images[k] = image_pad.resize((W, H), Image.BILINEAR)\n\n                mask_pad = F.pad(masks[k], (pad_l, pad_t, pad_r, pad_b), 1)\n                masks_new.append(mask_pad.resize((W, H), Image.NEAREST))\n\n        # return the resampled masks: they mark the pixels that\n        # became invalid after cropping/padding\n        return images, masks_new, affine\n\nclass MaskScaleSmallest(object):\n\n    def __init__(self, smallest_range):\n        self.size = smallest_range\n\n    def __call__(self, images, masks):\n        assert len(images) > 0, \"Non-empty array expected\"\n\n        new_size = self.size[0] + int((self.size[1] - self.size[0]) * random.random())\n\n        w, h = images[0].size\n        aspect = w / h\n\n        if aspect > 1:\n            new_h = new_size\n            new_w = int(new_size * aspect)\n        else:\n            new_w = new_size\n            new_h = int(new_size / aspect)\n\n        new_size = (new_w, new_h)\n\n        for i, (image, mask) in enumerate(zip(images, masks)):\n            assert image.size == mask.size\n            assert image.size == (w, h)\n\n            images[i] = image.resize(new_size, Image.BILINEAR)\n            masks[i] = mask.resize(new_size, Image.NEAREST)\n\n        return images, masks\n\nclass MaskRandCrop:\n\n    def __init__(self, 
size, pad_if_needed=False):\n        self.size = size # (h, w)\n        self.pad_if_needed = pad_if_needed\n\n    def __pad(self, img, padding_mode='constant', fill=0):\n\n        # pad the width if needed\n        pad_width = self.size[1] - img.size[0]\n        pad_height = self.size[0] - img.size[1]\n        if self.pad_if_needed and (pad_width > 0 or pad_height > 0):\n            pad_l = max(0, pad_width // 2)\n            pad_r = max(0, pad_width - pad_l)\n            pad_t = max(0, pad_height // 2)\n            pad_b = max(0, pad_height - pad_t)\n            img = F.pad(img, (pad_l, pad_t, pad_r, pad_b), fill, padding_mode)\n\n        return img\n\n    def __call__(self, images, masks):\n\n        for i, (image, mask) in enumerate(zip(images, masks)):\n            images[i] = self.__pad(image)\n            masks[i] = self.__pad(mask, fill=1)\n\n        i, j, h, w = tf.RandomCrop.get_params(images[0], self.size)\n\n        for k, (image, mask) in enumerate(zip(images, masks)):\n            images[k] = F.crop(image, i, j, h, w)\n            masks[k] = F.crop(mask, i, j, h, w)\n\n        return images, masks\n\nclass MaskCenterCrop:\n\n    def __init__(self, size):\n        self.size = size # (h, w)\n\n    def __call__(self, images, masks):\n\n        for i, (image, mask) in enumerate(zip(images, masks)):\n            images[i] = F.center_crop(image, self.size)\n            masks[i] = F.center_crop(mask, self.size)\n\n        return images, masks\n\nclass MaskRandHFlip:\n\n    def __call__(self, images, masks):\n\n        if random.random() > 0.5:\n\n            for i, (image, mask) in enumerate(zip(images, masks)):\n                images[i] = F.hflip(image)\n                masks[i] = F.hflip(mask)\n\n        return images, masks\n"
  },
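Why Compose above unpacks result[0] / result[1:]: the transforms in this file take a variable number of arguments and may change the arity mid-chain (GuidedRandHFlip, for instance, starts returning an extra affine list). A toy re-creation; the AddOne transform is invented for illustration:

class AddOne:
    def __call__(self, images, masks):
        return [x + 1 for x in images], masks

class Compose:
    # same threading logic as daugm_video.Compose
    def __init__(self, segtransform):
        self.segtransform = segtransform

    def __call__(self, args, *more_args):
        for t in self.segtransform:
            result = t(args, *more_args)
            args, more_args = result[0], result[1:]
        return result

images, masks = Compose([AddOne(), AddOne()])([0, 10], ["m0", "m1"])
print(images, masks)   # [2, 12] ['m0', 'm1']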
  {
    "path": "infer_vos.py",
    "content": "\"\"\"\nSingle-scale inference\nCopyright (c) 2021 TU Darmstadt\nAuthor: Nikita Araslanov <nikita.araslanov@tu-darmstadt.de>\nLicense: Apache License 2.0\n\"\"\"\n\nimport os\nimport sys\nimport numpy as np\nimport imageio\nimport time\n\nimport torch.multiprocessing as mp\nfrom tqdm import tqdm\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom opts import get_arguments\nfrom core.config import cfg, cfg_from_file, cfg_from_list\nfrom models import get_model\nfrom utils.timer import Timer\nfrom utils.sys_tools import check_dir\nfrom utils.palette_davis import palette as davis_palette\n\nfrom torch.utils.data import DataLoader\nfrom datasets.dataloader_infer import DataSeg\n\nfrom labelprop.common import LabelPropVOS_CRW\n\n# deterministic inference\nfrom torch.backends import cudnn\n\ncudnn.enabled = True\ncudnn.benchmark = False\ncudnn.deterministic = True\n\nVERBOSE = True\n\ndef mask2rgb(mask, palette):\n    mask_rgb = palette(mask)\n    mask_rgb = mask_rgb[:,:,:3]\n    return mask_rgb\n\ndef mask_overlay(mask, image, palette):\n    \"\"\"Creates an overlayed mask visualisation\"\"\"\n    mask_rgb = mask2rgb(mask, palette)\n    return 0.3 * image + 0.7 * mask_rgb\n\nclass ResultWriter:\n    \n    def __init__(self, key, palette, out_path):\n        self.key = key\n        self.palette = palette\n        self.out_path = out_path\n        self.verbose = VERBOSE\n\n    def save(self, frames, masks_pred, masks_conf, masks_gt, flags, fn, seq_name):\n\n        subdir_vos = os.path.join(self.out_path, \"{}_vos\".format(self.key))\n        check_dir(subdir_vos, seq_name)\n\n        subdir_vis = os.path.join(self.out_path, \"{}_vis\".format(self.key))\n        check_dir(subdir_vis, seq_name)\n\n        for frame_id, mask in enumerate(masks_pred.split(1, 0)):\n\n            mask = mask[0].numpy().astype(np.uint8)\n            filepath = os.path.join(subdir_vos, seq_name, \"{}.png\".format(fn[frame_id][0]))\n\n            # saving only every 5th frame\n            if flags[frame_id] != 0:\n                imageio.imwrite(filepath, mask)\n\n            if self.verbose:\n                frame = frames[frame_id].numpy()\n                #mask_gt = masks_gt[frame_id].numpy().astype(np.uint8)\n                #masks = np.concatenate([mask, mask_gt], 1)\n                #frame = np.concatenate([frame, frame], 2)\n                frame = np.transpose(frame, [1,2,0])\n\n                overlay = mask_overlay(mask, frame, self.palette)\n                filepath = os.path.join(subdir_vis, seq_name, \"{}.png\".format(fn[frame_id][0]))\n                imageio.imwrite(filepath, (overlay * 255.).astype(np.uint8))\n\n\ndef convert_dict(state_dict):\n    new_dict = {}\n    for k,v in state_dict.items():\n        new_key = k.replace(\"module.\", \"\")\n        new_dict[new_key] = v\n    return new_dict\n\ndef mask2tensor(mask, idx, num_classes=cfg.DATASET.NUM_CLASSES):\n    h,w = mask.shape\n    mask_t = torch.zeros(1,num_classes,h,w)\n    mask_t[0, idx] = mask\n    return mask_t\n\ndef configure_tracks(masks_gt, tracks, num_objects):\n    \"\"\"Selecting masks for initialisation\n\n    Args:\n        masks_gt: [T,H,W]\n        tracks: [T,2]\n\n    \"\"\"\n    init_masks = {}\n\n    # we always have first mask\n    # if there are no instances, it will be simply zero\n    H,W = masks_gt[0].shape[-2:]\n    init_masks[0] = torch.zeros(1, cfg.DATASET.NUM_CLASSES, H, W)\n\n    for oid in range(cfg.DATASET.NUM_CLASSES):\n\n        t = tracks[oid].item()\n        if not t in 
init_masks:\n            init_masks[t] = mask2tensor(masks_gt[oid], oid)\n        else:\n            init_masks[t] += mask2tensor(masks_gt[oid], oid)\n\n    return init_masks\n\ndef make_onehot(mask, HW):\n    # convert a mask tensor with probabilities to a one-hot tensor\n    b,c,h,w = mask.shape\n\n    mask_up = F.interpolate(mask, HW, mode=\"bilinear\", align_corners=True)\n    one_hot = torch.zeros_like(mask_up)\n    one_hot.scatter_(1, mask_up.argmax(1, keepdim=True), 1)\n    one_hot = F.interpolate(one_hot, (h,w), mode=\"bilinear\", align_corners=True)\n\n    return one_hot\n\ndef scale_smallest(frame, a):\n    H,W = frame.shape[-2:]\n    s = a / min(H, W)\n    h, w = int(s * H), int(s * W)\n    return F.interpolate(frame, (h, w), mode=\"bilinear\", align_corners=True)\n\ndef valid_mask(mask):\n    \"\"\"From a tensor [1,C,h,w]\n    create [1,C,1,1] 0/1 mask saying which IDs are present\"\"\"\n    B,C,h,w = mask.shape\n    valid = mask.flatten(2,3).sum(-1) > 0\n    valid = valid.type_as(mask).view(B,C,1,1)\n    return valid\n\ndef merge_mask_ids(masks, key0):\n    merged_mask = torch.zeros_like(masks[key0])\n    for tt, mask in masks.items():\n        merged_mask[:,1:] += mask[:,1:]\n\n    probs, ids = merged_mask.max(1, keepdim=True)\n    merged_mask.zero_()\n    # in-place scatter: write the winning probability into its channel\n    merged_mask.scatter_(1, ids, probs)\n    merged_mask[:, :1] = 1 - probs\n    return merged_mask\n\ndef step_seg(cfg, net, labelprop, frames, mask_init):\n\n    # dense tracking: start from the 1st frame\n    # keep track of new objects\n\n    T = frames.shape[0]\n    frames = frames.cuda()\n\n    # scale smallest\n    if cfg.TEST.INPUT_SIZE > 0:\n        frames = scale_smallest(frames, cfg.TEST.INPUT_SIZE)\n\n    fetch = {\"res3\": lambda x: x[0], \\\n             \"res4\": lambda x: x[1], \\\n             \"key\": lambda x: x[2]}\n\n    for t in mask_init.keys():\n        mask_init[t] = mask_init[t].cuda()\n\n    H,W = mask_init[0].shape[-2:]\n\n    scale_as = lambda x, y: F.interpolate(x, y.shape[-2:], mode=\"bilinear\", align_corners=True)\n    scale = lambda x, hw: F.interpolate(x, hw, mode=\"bilinear\", align_corners=True)\n\n    # context (cxt) will maintain\n    # the reference frame\n    ref_embd = {}   # context embeddings\n    ref_masks = {}\n    ref_valid = {}\n\n    all_masks = []\n    all_masks_conf = []\n\n    def add_result(mask):\n        mask_up = scale(mask, (H, W))\n        nxt_masks_conf, nxt_masks_id = mask_up.max(1)\n        all_masks.append(nxt_masks_id.cpu())\n        all_masks_conf.append(nxt_masks_conf.cpu())\n\n    # initialising\n    embd0 = net(frames[:1], embd_only=True, norm=True)\n    embd0 = fetch[cfg.TEST.KEY](embd0)\n    mask0 = scale_as(mask_init[0], embd0)\n    add_result(mask0)\n\n    ref_embd[0] = {0: embd0}\n    ref_masks[0] = {0: mask0}\n    ref_valid[0] = valid_mask(mask0) # [1,C,1,1]\n\n    # add this to the reference context\n    # if there are objects\n    ref_index = []\n    if mask_init[0].sum() > 0:\n        ref_index = [0]\n\n    print(\">\", end='')\n    for t in range(1, T):\n        print(\".\", end='')\n        sys.stdout.flush()\n\n        # next frame\n        frames_batch = frames[t:t+1]\n\n        # source forward pass\n        nxt_embds = net(frames_batch, embd_only=True, norm=True)\n\n        # fetching the feature\n        nxt_embd = fetch[cfg.TEST.KEY](nxt_embds)\n\n        ref_t = [0] if len(ref_index) == 0 else ref_index\n\n        # for each reference mask\n  
      # we will create own context, then\n        # propagate the labels and merge the result\n        nxt_masks = {}\n        for t0 in ref_t:\n            cxt_index = labelprop.context_index(t0, t)\n            cxt_embd = [ref_embd[t0][j] for j in cxt_index]\n            cxt_masks = [ref_masks[t0][j] for j in cxt_index]\n            nxt_masks[t0] = labelprop.predict(cxt_embd, cxt_masks, nxt_embd, cxt_index, t)\n\n        # merging all the masks\n        nxt_mask = sum([ref_valid[tt] * nxt_masks[tt] for tt in nxt_masks.keys()])\n        #nxt_mask = merge_mask_ids(nxt_masks, ref_t[0])\n\n        if t in mask_init: # not t >= 0\n            print(\"Adding GT mask t = \", t)\n            # adding the initial mask if just appeared\n            mask_init_dn = scale_as(mask_init[t], nxt_embd)\n            mask_init_dn_s = mask_init_dn.sum(1, keepdim=True)\n            nxt_mask = (1 - mask_init_dn_s) * nxt_mask + mask_init_dn_s * mask_init_dn\n\n            # adding to context\n            ref_embd[t] = {}\n            ref_masks[t] = {}\n            ref_valid[t] = valid_mask(mask_init[t])\n            ref_index.append(t)\n\n        add_result(nxt_mask)\n        ref_t = [0] if len(ref_index) == 0 else ref_index\n\n        #\n        # updating the context\n        # two parts: for every initial index, keep first N and last M frames\n        for t0 in ref_t:\n            ref_embd[t0][t] = nxt_embd.clone()\n            ref_masks[t0][t] = nxt_mask.clone()\n\n            index_short = labelprop.context_long(t0, t)\n\n            tsteps = list(ref_embd[t0].keys())\n            for tt in tsteps:\n                if t - tt > cfg.TEST.CXT_SIZE and not tt in index_short:\n                    del ref_embd[t0][tt]\n                    del ref_masks[t0][tt]\n\n    print('<')\n    masks_pred = torch.cat(all_masks, 0)\n    masks_pred_conf = torch.cat(all_masks_conf, 0)\n\n    return masks_pred, masks_pred_conf\n\n\nif __name__ == '__main__':\n\n    # loading the model\n    args = get_arguments(sys.argv[1:])\n\n    # reading the config\n    cfg_from_file(args.cfg_file)\n    if args.set_cfgs is not None:\n        cfg_from_list(args.set_cfgs)\n\n    # initialising the dirs\n    check_dir(args.mask_output_dir, \"{}_vis\".format(cfg.TEST.KEY))\n    check_dir(args.mask_output_dir, \"{}_vos\".format(cfg.TEST.KEY))\n\n    # Loading the model\n    model = get_model(cfg, remove_layers=cfg.MODEL.REMOVE_LAYERS)\n\n    labelprop = LabelPropVOS_CRW(cfg)\n\n    if not os.path.isfile(args.resume):\n        print(\"[W]: \", \"Snapshot not found: {}\".format(args.resume))\n        print(\"[W]: Using a random model\")\n    else:\n        state_dict = convert_dict(torch.load(args.resume)[\"model\"])\n        try:\n            model.load_state_dict(state_dict)\n        except Exception as e:\n            print(\"Error while loading the snapshot:\\n\", str(e))\n            print(\"Resuming...\")\n            model.load_state_dict(state_dict, strict=False)\n\n    for p in model.parameters():\n        p.requires_grad = False\n\n    # setting the evaluation mode\n    model.eval()\n    model = model.cuda()\n    dataset = DataSeg(cfg, args.infer_list)\n\n    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, \\\n                                    drop_last=False) #, num_workers=args.workers)\n    palette = dataloader.dataset.get_palette()\n\n    timer = Timer()\n    N = len(dataloader)\n\n    pool = mp.Pool(processes=args.workers)\n    writer = ResultWriter(cfg.TEST.KEY, davis_palette, args.mask_output_dir)\n\n    for iter, 
batch in enumerate(dataloader):\n        frames_orig, frames, masks_gt, tracks, num_ids, fns, flags, seq_name = batch\n\n        print(\"Sequence {:02d} | {}\".format(iter, seq_name[0]))\n\n        masks_gt = masks_gt.flatten(0,1)\n        frames_orig = frames_orig.flatten(0,1)\n        frames = frames.flatten(0,1)\n        tracks = tracks.flatten(0,1)\n        flags = flags.flatten(0,1)\n\n        init_masks = configure_tracks(masks_gt, tracks, num_ids[0])\n        assert 0 in init_masks, \"initial frame has no instances\"\n\n        with torch.no_grad():\n            masks_pred, masks_conf = step_seg(cfg, model, labelprop, frames, init_masks)\n\n        frames_orig = dataset.denorm(frames_orig)\n\n        pool.apply_async(writer.save, args=(frames_orig, masks_pred.cpu(), masks_conf.cpu(), masks_gt.cpu(), flags, fns, seq_name[0]))\n\n    timer.stage(\"Inference completed\")\n    pool.close()\n    pool.join()\n"
  },
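step_seg merges per-reference predictions as sum(ref_valid[tt] * nxt_masks[tt]), where ref_valid gates each reference frame by the object channels it actually contains. A standalone demo of that gate; tensor sizes are arbitrary:

import torch

def valid_mask(mask):
    # [1,C,h,w] -> [1,C,1,1] indicator of channels with any mass
    B, C, h, w = mask.shape
    valid = mask.flatten(2, 3).sum(-1) > 0
    return valid.type_as(mask).view(B, C, 1, 1)

m = torch.zeros(1, 3, 4, 4)
m[0, 1, 2, 2] = 1.0                     # only object 1 is present
print(valid_mask(m).flatten())          # tensor([0., 1., 0.])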
  {
    "path": "labelprop/common.py",
    "content": "\"\"\"\nBased on the inference routines from Jabri et al., (2020)\nCredit: https://github.com/ajabri/videowalk.git\nLicense: MIT\n\"\"\"\n\nimport sys\nimport torch\nfrom labelprop.crw import MaskedAttention\nfrom labelprop.crw import mem_efficient_batched_affinity as batched_affinity\n\nclass LabelPropVOS(object):\n\n    def context_long(self):\n        \"\"\"Returns indices of the timesteps\n        for long-term memory\n        \"\"\"\n        raise NotImplementedError()\n\n    def context_short(self, t):\n        \"\"\"\n        Args:\n            t: current timestep\n        Returns:\n            list: indices of timesteps\n                  for the context\n        \"\"\"\n        raise NotImplementedError()\n    \n    def predict(self, feats, masks, curr_feat):\n        \"\"\"\n        Args:\n            feats [C,K,h,w]: context features\n            masks [C,M,h,w]: context masks\n            curr_feat [1,K,h,w]: current frame features\n        Returns:\n            mask [1,M,h,w]: current frame mask\n        \"\"\"\n        raise NotImplementedError()\n\n\nclass LabelPropVOS_CRW(LabelPropVOS):\n\n    def __init__(self, cfg):\n        self.cxt_size = cfg.TEST.CXT_SIZE\n        self.radius = cfg.TEST.RADIUS\n        self.temperature = cfg.TEST.TEMP\n        self.topk = cfg.TEST.KNN\n        self.mask = None\n        self.mask_hw = None\n\n    def context_long(self, t0, t):\n        return [t0]\n\n    def context_short(self, t0, t):\n        to_t = t\n        from_t = to_t - self.cxt_size\n        timesteps = [max(tt, t0) for tt in range(from_t, to_t)]\n        return timesteps\n\n    def context_index(self, t0, t):\n        index_short = self.context_short(t0, t)\n        index_long = self.context_long(t0, t)\n        cxt_index = index_long + index_short\n        return cxt_index\n\n    def predict(self, feats, masks, curr_feat, ref_index=None, t=None):\n        \"\"\"\n        Args:\n            feats: list of C [1,K,h,w] context features\n            masks: list of C [1,M,h,w] context masks\n            curr_feat: [1,K,h,w] current frame features\n            ref_index: C indices of context frames\n            t: current frame time step\n        Returns:\n            mask [1,M,h,w]: current frame mask\n        \"\"\"\n        dev = curr_feat.device\n        h, w = curr_feat.shape[-2:]\n\n        # [BC+N,M,h,w] -> [BC+N,h,w,M]\n        ctx_lbls = torch.cat(masks, 0).permute([0,2,3,1])\n\n        # keys [1,K,C,h,w]: context features\n        keys = torch.stack(feats, 2)[:, :, None]\n        # query: [1,K,1,h,w]: reference feature\n        query = curr_feat[:, :, None]\n\n        if self.mask is None or self.mask_hw != (h, w):\n            # Make spatial radius mask TODO use torch.sparse\n            restrict = MaskedAttention(self.radius, flat=False)\n            D = restrict.mask(h, w)[None]\n            D = D.flatten(-4, -3).flatten(-2)\n            D[D==0] = -1e10; D[D==1] = 0\n            self.mask = D.to(dev)\n            self.mask_hw = (h, w)\n\n        # Flatten source frame features to make context feature set\n        keys, query = keys.flatten(-2), query.flatten(-2)\n\n        long_mem = [0]\n        Ws, Is = batched_affinity(query, keys, self.mask,  \\\n                self.temperature, self.topk, long_mem)\n\n        # Soft labels of source nodes\n        ctx_lbls = ctx_lbls.flatten(0, 2).transpose(0, 1)\n\n        # Weighted sum of top-k neighbours (Is is index, Ws is weight) \n        pred = (ctx_lbls[:, Is[0].to(dev)] * Ws[0][None].to(dev)).sum(1)\n    
    pred = pred.view(-1, h, w)\n\n        # [M,h,w] -> [1,M,h,w]\n        pred = pred[None, ...]\n\n        return pred\n"
  },
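The context scheme of LabelPropVOS_CRW in isolation: long-term memory is just the reference frame t0, short-term memory is the preceding cxt_size timesteps, clamped so no index precedes t0. A sketch with an assumed cxt_size of 4:

def context_index(t0, t, cxt_size=4):
    # long-term: the reference frame; short-term: the last cxt_size steps
    short = [max(tt, t0) for tt in range(t - cxt_size, t)]
    return [t0] + short

print(context_index(0, 2))   # [0, 0, 0, 0, 1]
print(context_index(0, 7))   # [0, 3, 4, 5, 6]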
  {
    "path": "labelprop/crw.py",
    "content": "\"\"\"\nInference routines from Jabri et al., (2020)\nCredit: https://github.com/ajabri/videowalk.git\nLicense: MIT\n\"\"\"\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport time\n\nclass CRW(object):\n\n    \"\"\"Propagation algorithm\"\"\"\n    def __init__(self, cfg):\n        self.n_context = cfg.CXT_SIZE\n        self.radius = cfg.RADIUS\n        self.temperature = cfg.TEMP\n        self.topk = cfg.KNN\n\n        print(\"Inference Opts:\")\n        print(\"Context size: {}\".format(self.n_context))\n        print(\"      Radius: {}\".format(self.radius))\n        print(\"        Temp: {}\".format(self.temperature))\n        print(\"        TopK: {}\".format(self.topk))\n\n        # always keeping the first frame\n        # TODO: move to cfg\n        self.long_mem = [0]\n        # for bwd-compatibility\n        self.norm_mask = False\n\n        self.mask = None\n        self.mask_hw = None\n\n    def _prep_context(self, feats, lbls, hw):\n        \"\"\"Adjust for context\n        Args:\n            lbls: [N,M,H,W]\n        \"\"\"\n        lbls = F.interpolate(lbls, hw, mode=\"bilinear\", align_corners=True)\n\n        ref = lbls[:1].expand(self.n_context,-1,-1,-1)\n        lbls = torch.cat([ref, lbls], 0)\n\n        fref = feats[:1].expand(self.n_context,-1,-1,-1)\n        fref = torch.cat([fref, feats], 0)\n\n        return fref, lbls\n\n    def forward(self, feats, lbls):\n        \"\"\"Propagate features\n\n        Args:\n            feats: [N,K,h,w]\n            lbls: [N,M,H,W]\n        \"\"\"\n        N,K,h,w = feats.shape\n        M,H,W = lbls.shape[-3:]\n        feats, lbls = self._prep_context(feats, lbls, (h,w))\n\n        # [N,K,h,w] -> [1,K,N,h,w]\n        feats = feats.permute([1,0,2,3])\n        # singleton for compatibility\n        feats = feats[None,...]\n\n        # [BC+N,M,h,w] -> [BC+N,h,w,M]\n        lbls = lbls.permute([0,2,3,1])\n\n        n_context = self.n_context\n\n        torch.cuda.empty_cache()\n\n        # Prepare source (keys) and target (query) frame features\n        key_indices = context_index_bank(n_context, self.long_mem, N)\n        key_indices = torch.cat(key_indices, dim=-1)\n        keys = feats[:, :, key_indices]\n        query = feats[:, :, n_context:]\n\n        # Make spatial radius mask TODO use torch.sparse\n        if self.mask is None or self.mask_hw != (h, w):\n            restrict = MaskedAttention(self.radius, flat=False)\n            D = restrict.mask(h, w)[None]\n            D = D.flatten(-4, -3).flatten(-2)\n            D[D==0] = -1e10; D[D==1] = 0\n            self.mask = D.cuda()\n            self.mask_hw = (h, w)\n\n        # Flatten source frame features to make context feature set\n        keys, query = keys.flatten(-2), query.flatten(-2)\n\n        Ws, Is = mem_efficient_batched_affinity(query, keys, self.mask,  \\\n                self.temperature, self.topk, self.long_mem)\n        \n        ##################################################################\n        # Propagate Labels and Save Predictions\n        ###################################################################\n\n        masks_idx = torch.LongTensor(N,H,W)\n        masks_prob = torch.FloatTensor(N,H,W)\n\n        for t in range(key_indices.shape[0]):\n            # Soft labels of source nodes\n            ctx_lbls = lbls[key_indices[t]].cuda()\n            ctx_lbls = ctx_lbls.flatten(0, 2).transpose(0, 1)\n\n            # Weighted sum of top-k neighbours (Is is index, Ws is weight) \n            pred = 
(ctx_lbls[:, Is[t]] * Ws[t][None].cuda()).sum(1)\n            pred = pred.view(-1, h, w)\n            pred = pred.permute(1,2,0)\n\n            if t > 0:\n                lbls[t + n_context] = pred\n            else:\n                pred = lbls[0]\n                lbls[t + n_context] = pred\n\n            if self.norm_mask:\n                pred[:, :, :] -= pred.min(-1)[0][:, :, None]\n                pred[:, :, :] /= pred.max(-1)[0][:, :, None]\n\n            # Adding Predictions            \n            pred_ = pred.permute([2,0,1])[None, ...]\n            pred_up = F.interpolate(pred_, (H,W), mode=\"bilinear\", align_corners=True)\n            pred_up = pred_up[0].cpu()\n            masks_idx[t] = pred_up.argmax(0)\n            masks_prob[t] = pred_up[1]\n\n        out = {}\n        out[\"masks_pred_idx\"] = masks_idx\n        out[\"masks_pred_conf\"] = masks_prob\n        return out\n\ndef context_index_bank(n_context, long_mem, N):\n    '''\n    Construct bank of source frames indices, for each target frame\n    '''\n    ll = []   # \"long term\" context (i.e. first frame)\n    for t in long_mem:\n        assert 0 <= t < N, 'context frame out of bounds'\n        idx = torch.zeros(N, 1).long()\n        if t > 0:\n            idx += t + (n_context+1)\n            idx[:n_context+t+1] = 0\n        ll.append(idx)\n    # \"short\" context    \n    ss = [(torch.arange(n_context)[None].repeat(N, 1) +  torch.arange(N)[:, None])[:, :]]\n\n    return ll + ss\n\ndef batched_affinity(query, keys, mask, temperature, topk, long_mem, device):\n    '''\n    Mini-batched computation of affinity, for memory efficiency\n    (less aggressively mini-batched)\n    '''\n\n    A = torch.einsum('ijklm,ijkn->iklmn', keys, query)\n\n    # Mask\n    A[0, :, len(long_mem):] += mask.to(device)\n\n    _, N, T, h1w1, hw = A.shape\n    A = A.view(N, T*h1w1, hw)\n    A /= temperature\n\n    weights, ids = torch.topk(A, topk, dim=-2)\n    weights = F.softmax(weights, dim=-2)\n\n    Ws = [w for w in weights]\n    Is = [ii for ii in ids]\n\n    return Ws, Is\n\ndef mem_efficient_batched_affinity(query, keys, mask, temperature, topk, long_mem):\n    '''\n    Mini-batched computation of affinity, for memory efficiency\n    '''\n    bsize, pbsize = 2, 100\n    Ws, Is = [], []\n\n    for b in range(0, keys.shape[2], bsize):\n        _k, _q = keys[:, :, b:b+bsize].cuda(), query[:, :, b:b+bsize].cuda()\n        w_s, i_s = [], []\n\n        for pb in range(0, _k.shape[-1], pbsize):\n            A = torch.einsum('ijklm,ijkn->iklmn', _k, _q[..., pb:pb+pbsize]) \n            A[0, :, len(long_mem):] += mask[..., pb:pb+pbsize]\n\n            _, N, T, h1w1, hw = A.shape\n            A = A.view(N, T*h1w1, hw)\n            A /= temperature\n\n            weights, ids = torch.topk(A, topk, dim=-2)\n            weights = F.softmax(weights, dim=-2)\n\n            w_s.append(weights)\n            i_s.append(ids)\n\n        weights = torch.cat(w_s, dim=-1)\n        ids = torch.cat(i_s, dim=-1)\n        Ws += [w for w in weights]\n        Is += [ii for ii in ids]\n\n    return Ws, Is\n\n\nclass MaskedAttention(nn.Module):\n    '''\n    A module that implements masked attention based on spatial locality \n    TODO implement in a more efficient way (torch sparse or correlation filter)\n    '''\n    def __init__(self, radius, flat=True):\n        super(MaskedAttention, self).__init__()\n        self.radius = radius\n        self.flat = flat\n        self.masks = {}\n        self.index = {}\n\n    def mask(self, H, W):\n        if not ('%s-%s' 
%(H,W) in self.masks):\n            self.make(H, W)\n        return self.masks['%s-%s' %(H,W)]\n\n    def index(self, H, W):\n        if not ('%s-%s' %(H,W) in self.index):\n            self.make_index(H, W)\n        return self.index['%s-%s' %(H,W)]\n\n    def make(self, H, W):\n        if self.flat:\n            H = int(H**0.5)\n            W = int(W**0.5)\n        \n        gx, gy = torch.meshgrid(torch.arange(0, H), torch.arange(0, W))\n        D = ( (gx[None, None, :, :] - gx[:, :, None, None])**2 + (gy[None, None, :, :] - gy[:, :, None, None])**2 ).float() ** 0.5\n        D = (D < self.radius)[None].float()\n\n        if self.flat:\n            D = self.flatten(D)\n        self.masks['%s-%s' %(H,W)] = D\n\n        return D\n\n    def flatten(self, D):\n        return torch.flatten(torch.flatten(D, 1, 2), -2, -1)\n\n    def make_index(self, H, W, pad=False):\n        mask = self.mask(H, W).view(1, -1).byte()\n        idx = torch.arange(0, mask.numel())[mask[0]][None]\n\n        self.index['%s-%s' %(H,W)] = idx\n\n        return idx\n        \n    def forward(self, x):\n        H, W = x.shape[-2:]\n        sid = '%s-%s' % (H,W)\n        if sid not in self.masks:\n            self.masks[sid] = self.make(H, W).to(x.device)\n        mask = self.masks[sid]\n\n        return x * mask[0]\n"
  },
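What MaskedAttention.make computes, reduced to a few lines: pairwise Euclidean distances between all grid positions, thresholded at the radius. In the propagation code the zeros of this mask are then turned into -1e10 logits, so the top-k attention only ever sees spatial neighbours. Grid size and radius below are toy values:

import torch

def radius_mask(H, W, radius):
    gy, gx = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
    dist = ((gx[None, None] - gx[:, :, None, None]) ** 2 +
            (gy[None, None] - gy[:, :, None, None]) ** 2).float() ** 0.5
    return (dist < radius).float()       # [H,W,H,W]

D = radius_mask(4, 4, radius=1.5).flatten(0, 1).flatten(-2)   # [16,16]
print(D[5].view(4, 4))   # the 3x3 neighbourhood of position (1,1)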
  {
    "path": "launch/infer_vos.sh",
    "content": "#!/bin/bash\n\n#\n# Arguments\n#\n\n# suffix for the output directory (see below)\nVER=v01\n\n# defines the path to the snapshot (see below)\nEXP=ours\nRUN_ID=final\n\n# SNAPSHOT name. Note .pth will be attached\n# See README.md to download these snapshots\n\nSNAPSHOT=ytvos_e060_res4       # YouTube-VOS\n#SNAPSHOT=trackingnet_e088_res4 # TrackingNet\n#SNAPSHOT=oxuva_e430_res4       # OxUvA\n#SNAPSHOT=kinetics_e026_res4    # Kinetics\n\n# codename of the final output layer [res3|res4|key]\nKEY=res4\n\n#\n# Changing the following is not necessary\n#\n\nDS=$1\ncase $DS in\nytvos)\n  echo \"Test dataset: YouTube-VOS 2018 (val)\"\n  FILELIST=filelists/val_ytvos2018_test\n  ;;\ndavis)\n  echo \"Test dataset: DAVIS-2017 (val)\"\n  FILELIST=filelists/val_davis2017_test\n  ;;\n*)\n  echo \"Dataset '$DS' not recognised. Should be one of [ytvos|davis].\"\n  exit 1\n  ;;\nesac\n\n# The config file is irrelevant\n# since the inference parameters are\n# always the same\nCONFIG=configs/ytvos.yaml\nOUTPUT_DIR=./results\n\nEXTRA=\"$EXTRA --seed 0 --set TEST.KEY $KEY\"\nSAVE_ID=${RUN_ID}_${SNAPSHOT}_${VER}\n\nSNAPSHOT_PATH=snapshots/${EXP}/${RUN_ID}/${SNAPSHOT}.pth\nif [ ! -f $SNAPSHOT_PATH ]; then\n  echo \"Snapshot $SNAPSHOT_PATH NOT found.\"\n  exit 1;\nfi\n\n\n#\n# Code goes here\n#\nLISTNAME=`basename $FILELIST .txt`\nSAVE_DIR=$OUTPUT_DIR/$EXP/$SAVE_ID/$LISTNAME\nLOG_FILE=$OUTPUT_DIR/$EXP/$SAVE_ID/${LISTNAME}_${KEY}.log\n\nNUM_THREADS=12\nexport OMP_NUM_THREADS=$NUM_THREADS\nexport MKL_NUM_THREADS=$NUM_THREADS\n\n\nCMD=\"python infer_vos.py   --cfg $CONFIG \\\n                           --exp $EXP \\\n                           --run $RUN_ID \\\n                           --resume $SNAPSHOT_PATH \\\n                           --infer-list $FILELIST \\\n                           --mask-output-dir $SAVE_DIR \\\n                           $EXTRA\"\n\nif [ ! -d $SAVE_DIR ]; then\n  echo \"Creating directory: $SAVE_DIR\"\n  mkdir -p $SAVE_DIR\nelse\n  echo \"Saving to: $SAVE_DIR\"\nfi\n\ngit rev-parse HEAD > ${SAVE_DIR}.head\ngit diff > ${SAVE_DIR}.diff\necho $CMD > ${SAVE_DIR}.cmd\n\necho $CMD\nnohup $CMD > $LOG_FILE 2>&1 &\n\nsleep 1\ntail -f $LOG_FILE\n"
  },
  {
    "path": "launch/train.sh",
    "content": "#!/bin/bash\n\n# Set the following variables\n# The tensorboard logging will be creating in logs/<EXP>/<EXP_ID>\n# The snapshots will be saved in snapshots/<EXP>/<EXP_ID>\nEXP=v01_vos\nEXP_ID=v01_00_base\n\n#\n# No change are necessary starting here\n#\n\nDS=$1\nSEED=32\n\ncase $DS in\noxuva)\n  echo \"Train dataset: OxUvA\"\n  TASK=\"OxUvA_all\"\n  CFG=configs/oxuva.yaml\n  EXP_ID=\"OX_${EXP_ID}\"\n  ;;\nytvos)\n  echo \"Train dataset: YouTube-VOS\"\n  TASK=\"YTVOS\"\n  CFG=configs/ytvos.yaml\n  EXP_ID=\"YT_${EXP_ID}\"\n  ;;\ntrack)\n  echo \"Train dataset: TrackingNet\"\n  TASK=\"TrackingNet\"\n  CFG=configs/tracknet.yaml\n  EXP_ID=\"TN_${EXP_ID}\"\n  ;;\nkinetics)\n  echo \"Train dataset: Kinetics-400\"\n  CFG=configs/kinetics.yaml\n  EXP_ID=\"KT_${EXP_ID}\"\n  ;;\n*)\n  echo \"Dataset '$DS' not recognised. Should be one of [oxuva|ytvos|track|kinetics].\"\n  exit 1\n  ;;\nesac\n\n\nCURR_DIR=\"$( cd \"$( dirname \"${BASH_SOURCE[0]}\" )\" >/dev/null 2>&1 && pwd )\"\nsource $CURR_DIR/utils.bash\n\nCMD=\"python train.py --cfg $CFG --exp $EXP --run $EXP_ID --seed $SEED\"\nLOG_DIR=logs/${EXP}\nLOG_FILE=$LOG_DIR/${EXP_ID}.log\necho \"LOG: $LOG_FILE\"\n\ncheck_rundir $LOG_DIR $EXP_ID\n\nNUM_THREADS=12\n\nexport OMP_NUM_THREADS=$NUM_THREADS\nexport MKL_NUM_THREADS=$NUM_THREADS\n\necho $CMD\n\nCMD_FILE=$LOG_DIR/${EXP_ID}.cmd\necho $CMD > $CMD_FILE\n\ngit rev-parse HEAD > $LOG_DIR/${EXP_ID}.head\ngit diff > $LOG_DIR/${EXP_ID}.diff\n\nnohup $CMD > $LOG_FILE 2>&1 &\nsleep 1\ntail -f $LOG_FILE\n"
  },
  {
    "path": "launch/utils.bash",
    "content": "#!/bin/bash\n\ncheck_rundir()\n{\n  LOG_DIR=\"$1\"\n  EXP_ID=\"$2\"\n\n  if [ ! -d \"$LOG_DIR\" ]; then\n    echo \"Creating directory $LOG_DIR\"\n    mkdir -p $LOG_DIR\n  else\n    LOGD=$LOG_DIR/$EXP_ID\n    if [ -d \"$LOGD\" ]; then\n      echo \"Directory $LOGD already exists.\"\n      read -p \"Do you want to remove the log files?: \" -n 1 -r\n      echo \n      if [[ ! $REPLY =~ ^[Yy]$ ]]\n      then\n        exit;\n      else\n        rm -rf $LOGD\n      fi\n    fi\n  fi\n}\n\n"
  },
  {
    "path": "models/__init__.py",
    "content": "\"\"\"\nCopyright (c) 2021 TU Darmstadt\nAuthor: Nikita Araslanov <nikita.araslanov@tu-darmstadt.de>\nLicense: Apache License 2.0\n\"\"\"\n\nfrom .resnet18 import resnet18\nfrom .net import Net\nfrom .framework import Framework\n\ndef get_model(cfg, *args, **kwargs):\n\n    backbones = {\n        'resnet18': resnet18\n    }\n\n    def create_net():\n        backbone = backbones[cfg.MODEL.ARCH.lower()](*args, **kwargs)\n        return Net(cfg, backbone)\n\n    net = create_net()\n    return Framework(cfg, net)\n"
  },
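The registry/factory pattern of get_model in miniature; the stand-in backbone constructor below is invented, and in the repo the backbone is additionally wrapped in Net before being handed to Framework:

def make_resnet18(**kwargs):
    return {"arch": "resnet18", **kwargs}     # stand-in for models.resnet18

BACKBONES = {"resnet18": make_resnet18}

def get_model(arch, **kwargs):
    backbone = BACKBONES[arch.lower()](**kwargs)
    return ("Framework", ("Net", backbone))   # wrapping order as above

# mirrors the call site in infer_vos.py: get_model(cfg, remove_layers=...)
print(get_model("ResNet18", remove_layers=["layer4"]))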
  {
    "path": "models/base.py",
    "content": "\"\"\"\nBase class for network models\nCopyright (c) 2021 TU Darmstadt\nAuthor: Nikita Araslanov <nikita.araslanov@tu-darmstadt.de>\nLicense: Apache License 2.0\n\"\"\"\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass BaseNet(nn.Module):\n\n    _trainable = (nn.Linear, nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d, nn.GroupNorm, nn.InstanceNorm2d, nn.SyncBatchNorm)\n    _batchnorm = (nn.BatchNorm2d, nn.SyncBatchNorm, nn.GroupNorm)\n\n    def __init__(self):\n        super().__init__()\n        # we may want a different learning rate\n        # for new layers\n        self.from_scratch_layers = []\n\n        # we may want to freeze some layers\n        self.not_training = []\n\n        # we may want to freeze BN (means/stds only)\n        self.bn_freeze = []\n\n    def lr_mult(self):\n        \"\"\"Learning rate multiplier for weights.\n        Returns: [old, new]\"\"\"\n        return 1., 1.\n\n    def lr_mult_bias(self):\n        \"\"\"Learning rate multiplier for bias.\n        Returns: [old, new]\"\"\"\n        return 2., 2.\n\n    def _is_learnable(self, layer):\n         return isinstance(layer, BaseNet._trainable)\n\n    def _from_scratch(self, net, ignore=[]):\n\n        for layer in net.modules():\n            if self._is_learnable(layer):\n                self.from_scratch_layers.append(layer)\n\n    def _freeze_bn(self, net, ignore=[]):\n        \"\"\"Add layers to use in .eval() mode only\"\"\"\n\n        for layer in net.modules():\n            if isinstance(layer, BaseNet._batchnorm) and \\\n                    not layer in ignore:\n\n                assert hasattr(layer, \"eval\") and callable(layer.eval)\n                print(\"Freezing \", layer)\n                self.bn_freeze.append(layer)\n\n        print(\"Frozen BN: \", len(self.bn_freeze))\n\n    def _fix_bn(self, layer):\n        if isinstance(layer, nn.BatchNorm2d):\n            self.not_training.append(layer)\n\n        elif isinstance(layer, nn.Module):\n            for c in layer.children():\n                self._fix_bn(c)\n\n    def __set_grad_mode(self, layer, mode, only_type=None):\n\n        if hasattr(layer, \"weight\"):\n            if only_type is None or isinstance(layer, only_type):\n                layer.weight.requires_grad = mode\n\n        if hasattr(layer, \"bias\") and not layer.bias is None:\n            if only_type is None or isinstance(layer, only_type):\n                layer.bias.requires_grad = mode\n\n        if isinstance(layer, nn.Module):\n            for c in layer.children():\n                self.__set_grad_mode(c, mode)\n\n    def train(self, mode=True):\n        super().train(mode)\n\n        # some layers have to be frozen\n        for layer in self.not_training:\n            self.__set_grad_mode(layer, False)\n\n        for layer in self.bn_freeze:\n            assert hasattr(layer, \"eval\") and callable(layer.eval)\n            layer.eval()\n\n    def parameter_groups(self, base_lr, wd):\n\n        w_old, w_new = self.lr_mult()\n        b_old, b_new = self.lr_mult_bias()\n\n        groups = ({\"params\": [], \"weight_decay\":  wd, \"lr\": w_old*base_lr}, # weight learning\n                  {\"params\": [], \"weight_decay\": 0.0, \"lr\": b_old*base_lr}, # bias learning\n                  {\"params\": [], \"weight_decay\":  wd, \"lr\": w_new*base_lr}, # weight finetuning\n                  {\"params\": [], \"weight_decay\": 0.0, \"lr\": b_new*base_lr}) # bias finetuning\n\n        for m in self.modules():\n\n            if not 
self._is_learnable(m):\n                if hasattr(m, \"weight\") or hasattr(m, \"bias\"):\n                    print(\"Skipping layer with parameters: \", m)\n                continue\n\n            if m.weight is not None and m.weight.requires_grad:\n                if m in self.from_scratch_layers:\n                    groups[2][\"params\"].append(m.weight)\n                else:\n                    groups[0][\"params\"].append(m.weight)\n            elif m.weight is not None:\n                print(\"Skipping W: \", m, m.weight.size())\n\n            if m.bias is not None and m.bias.requires_grad:\n                if m in self.from_scratch_layers:\n                    groups[3][\"params\"].append(m.bias)\n                else:\n                    groups[1][\"params\"].append(m.bias)\n            elif m.bias is not None:\n                print(\"Skipping b: \", m, m.bias.size())\n\n        return groups\n\n    @staticmethod\n    def _resize_as(x, y):\n        return F.interpolate(x, y.size()[-2:], mode=\"bilinear\", align_corners=True)\n"
  },
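parameter_groups above returns four per-group dicts (pretrained/new x weight/bias), each carrying its own lr and weight_decay, and a torch optimizer accepts such an iterable directly. A reduced sketch with a toy module and assumed hyper-parameters (two groups instead of four):

import torch.nn as nn
import torch.optim as optim

net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
base_lr, wd = 1e-3, 5e-4

groups = ({"params": [], "weight_decay": wd,  "lr": 1.0 * base_lr},   # weights
          {"params": [], "weight_decay": 0.0, "lr": 2.0 * base_lr})   # biases: no decay

for m in net.modules():
    if isinstance(m, (nn.Conv2d, nn.BatchNorm2d)):
        groups[0]["params"].append(m.weight)
        if m.bias is not None:
            groups[1]["params"].append(m.bias)

optimizer = optim.SGD(groups, lr=base_lr, momentum=0.9)
print([len(g["params"]) for g in optimizer.param_groups])   # [2, 2]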
  {
    "path": "models/framework.py",
    "content": "\"\"\"\nCopyright (c) 2021 TU Darmstadt\nAuthor: Nikita Araslanov <nikita.araslanov@tu-darmstadt.de>\nLicense: Apache License 2.0\n\"\"\"\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom models.base import BaseNet\n\nclass Framework(BaseNet):\n\n    def __init__(self, cfg, net):\n        super(Framework, self).__init__()\n\n        self.cfg = cfg\n        self.fast_net = net\n        self.eye = None\n\n    def parameter_groups(self, base_lr, wd):\n        return self.fast_net.parameter_groups(base_lr, wd)\n\n    def _align(self, x, t):\n        tf = F.affine_grid(t, size=x.size(), align_corners=False)\n        return F.grid_sample(x, tf, align_corners=False, mode=\"nearest\")\n\n    def _key_val(self, ctr, q):\n        \"\"\"\n        Args:\n            ctr: [N,K]\n            q: [BHW,K]\n        Returns:\n            val: [BHW,N]\n        \"\"\"\n\n        # [BHW,K] x [N,K].t -> [BHWxN]\n        vals = torch.mm(q, ctr.t()) # [BHW,N]\n\n        # normalising attention\n        return vals / self.cfg.TEST.TEMP\n\n    def _sample_index(self, x, T, N):\n        \"\"\"Sample indices of the anchors\n\n        Args:\n            x: [BT,K,H,W]\n        Returns:\n            index: [B,N*N,K]\n        \"\"\"\n\n        BT,K,H,W = x.shape\n        B = x.view(-1,T,K,H*W).shape[0]\n\n        # sample indices from a uniform grid\n        xs, ys = W // N, H // N\n        x_sample = torch.arange(0, W, xs).view(1, 1, N)\n        y_sample = torch.arange(0, H, ys).view(1, N, 1)\n\n        # Random offsets\n        # [B x 1 x N]\n        x_sample = x_sample + torch.randint(0, xs, (B, 1, 1))\n        # [B x N x 1]\n        y_sample = y_sample + torch.randint(0, ys, (B, 1, 1))\n\n        # batch index\n        # [B x N x N]\n        hw_index = torch.LongTensor(x_sample + y_sample * W)\n\n        return hw_index\n\n    def _sample_from(self, x, index, T, N):\n        \"\"\"Gather the features based on the index\n\n        Args:\n            x: [BT,K,H,W]\n            index: [B,N,N] defines the indices of NxN grid for a single\n                           frame in each of B video clips\n        Returns:\n            anchors: [BNN,K] sampled features given by index from x\n        \"\"\"\n\n        BT,K,H,W = x.shape\n        x = x.view(-1,T,K,H*W)\n        B = x.shape[0]\n\n        # > [B,T,K,HW] > [B,T,HW,K] > [B,THW,K]\n        x = x.permute([0,1,3,2]).reshape(B,-1,K)\n\n        # every video clip will have the same samples\n        # on the grid\n        # [B x N x N] -> [B x N*N x 1] -> [B x N*N x K]\n        index = index.view(B,-1,1).expand(-1,-1,K)\n\n        # selecting from the uniform grid\n        y = x.gather(1, index.to(x.device))\n\n        # [BNN,K]\n        return y.flatten(0,1)\n\n    def _mark_from(self, x, index, T, N, fill_value=0):\n        \"\"\"This is analogous to _sample_from except that\n        here we simply \"mark\" the sampled positions in the tensor\n        Used for visualisation only.\n        Since it is a binary mask, K == 1\n\n        Args:\n            x: [BT,1,H,W] binary mask\n            index: [B,N,N] defines the indices of NxN grid for a single\n                           frame in each of B video clips\n        Returns:\n            y: [BT,1,H,W] marked sample positions\n        \"\"\"\n\n        BT,K,H,W = x.shape\n        assert K == 1, \"Expected binary mask\"\n        x = x.view(-1,T,K,H*W)\n        B = x.shape[0]\n\n        # > [B,T,K,HW] > [B,T,HW,K] > [B,THW,K]\n        x = x.permute([0,1,3,2]).reshape(B,-1,K)\n\n      
  # every video clip will have the same samples\n        # on the grid\n        # [B x N x N] -> [B x N*N x 1] -> [B x N*N x K]\n        index = index.view(B,-1,1).expand(-1,-1,K)\n\n        # selecting from the uniform grid\n        # [B x T*H*W x K]\n        y = x.scatter(1, index.to(x.device), fill_value)\n\n        # [B x T*H*W x K] -> [BT x K x H x W]\n        return y.view(-1,H*W,K).permute([0,2,1]).view(-1,K,H,W)\n\n    def _cluster_grid(self, k1, k2, aff1, aff2, T, index=None):\n        \"\"\" Selecting clusters within a sequence\n        Args:\n            k1: [BT,K,H,W]\n            k2: [BT,K,H,W]\n        \"\"\"\n\n        BT,K,H,W = k1.shape\n        assert BT % T == 0, \"Batch not divisible by sequence length\"\n        B = BT // T\n\n        # N = [G x G]\n        N = self.cfg.MODEL.GRID_SIZE ** 2\n\n        # [BT,K,H,W] -> [BTHW,K]\n        flatten = lambda x: x.flatten(2,3).permute([0,2,1]).flatten(0,1)\n\n        # [BTHW,BN] -> [BT,BN,H,W]\n        def unflatten(x, aff=None):\n            x = x.view(BT,H*W,-1).permute([0,2,1]).view(BT,-1,H,W)\n            if aff is None:\n                return x\n            return self._align(x, aff)\n\n        index = self._sample_index(k1, T, N = self.cfg.MODEL.GRID_SIZE)\n        query1 = self._sample_from(k1, index, T, N = self.cfg.MODEL.GRID_SIZE)\n\n        \"\"\"Computing the distances and pseudo labels\"\"\"\n\n        # [BTHW,K]\n        k1_ = flatten(k1)\n        k2_ = flatten(k2)\n\n        # [BTHW,BN] -> [BTHW,BN] -> [BT,BN,H,W]\n        vals_soft = unflatten(self._key_val(query1, k1_), aff1)\n        vals_pseudo = unflatten(self._key_val(query1, k2_), aff2)\n\n        # [BT,BN,H,W]\n        probs_pseudo = self._pseudo_mask(vals_pseudo, T)\n        probs_pseudo2 = self._pseudo_mask(vals_soft, T)\n\n        pseudo = probs_pseudo.argmax(1)\n        pseudo2 = probs_pseudo2.argmax(1)\n\n        # mask\n        def grid_mask():\n            grid_mask = torch.ones(BT,1,H,W).to(pseudo.device)\n            return self._mark_from(grid_mask, index, T, N = self.cfg.MODEL.GRID_SIZE)\n\n        return vals_soft, pseudo, index, [vals_pseudo, pseudo2, grid_mask]\n\n    # sampling affinity\n    def _aff_sample(self, k1, k2, T):\n        BT,K,h,w = k1.size()\n        B = BT // T\n        hw = h*w\n\n        def gen(query):\n            grid_mask = torch.ones(B,1,hw).to(k1.device)\n            # generating random indices\n            indices = torch.randint(0, hw, (B,1,1)).to(k1.device)\n            grid_mask.scatter_(2, indices, 0)\n\n            # [B,K,H,W] -> [B,K,1]\n            query_ = query[::T].view(B,K,-1).gather(2, indices.expand(-1,K,-1))\n\n            def aff(keys):\n                k = keys.view(B,T,K,-1)\n                # [B,T,K,HW] x [B,1,K,HW] -> [B,T,HW]\n                aff = (k * query_[:,None,:,:]).sum(2)\n                return (aff + 1) / 2\n\n\n            aff1 = aff(k1)\n            aff2 = aff(k2)\n\n            return grid_mask.view(B,h,w), aff1.view(BT,h,w), aff2.view(BT,h,w)\n\n        grid_mask1, aff1_1, aff1_2 = gen(k1)\n        grid_mask2, aff2_1, aff2_2 = gen(k2)\n\n        return grid_mask1, aff1_1, aff1_2, \\\n                grid_mask2, aff2_1, aff2_2\n\n    def _pseudo_mask(self, logits, T):\n        BT,K,h,w = logits.shape\n        assert BT % T == 0, \"Batch not divisible by sequence length\"\n        B = BT // T\n\n        # N = [G x G]\n        N = self.cfg.MODEL.GRID_SIZE ** 2\n\n        # generating a pseudo label\n        # first we need to mask out the affinities across the batch\n        if 
self.eye is None or self.eye.shape[0] != B*T \\\n                            or self.eye.shape[1] != B*N:\n            eye = torch.eye(B)[:,:,None].expand(-1,-1,N).reshape(B,-1)\n            eye = eye.unsqueeze(1).expand(-1,T,-1).reshape(B*T, B*N, 1, 1)\n            self.eye = eye.to(logits.device)\n\n        probs = F.softmax(logits, 1)\n        return probs * self.eye\n\n    def _ref_loss(self, x, y, N = 4):\n        B,_,h,w = x.shape\n\n        index = self._sample_index(x, T=1, N=N)\n        x1 = self._sample_from(x, index, T=1, N=N)\n        y1 = self._sample_from(y, index, T=1, N=N)\n        logits = torch.mm(x1, y1.t()) / self.cfg.TEST.TEMP\n\n        labels = torch.arange(logits.size(1)).to(logits.device)\n        return F.cross_entropy(logits, labels)\n\n    def _ce_loss(self, x, pseudo_map, T, eps=1e-5):\n        error_map = F.cross_entropy(x, pseudo_map, reduction=\"none\", ignore_index=-1)\n\n        BT,h,w = error_map.shape\n        errors = error_map.view(-1,T,h,w)\n        error_ref, error_t = errors[:,0], errors[:,1:]\n\n        return error_ref.mean(), error_t.mean(), error_map\n\n    def _forward_reg(self, frames2, norm):\n        losses = {}\n\n        if not self.cfg.TRAIN.STOP_GRAD:\n            k2, res3, res4 = self.fast_net(frames2, norm)\n            return k2, res3, res4, losses\n\n        training = self.fast_net.training\n        if self.cfg.TRAIN.BLOCK_BN:\n            self.fast_net.eval()\n\n        with torch.no_grad():\n            k2, res3, res4 = self.fast_net(frames2, norm)\n\n        if self.cfg.TRAIN.BLOCK_BN:\n            self.fast_net.train(training)\n\n        return k2, res3, res4, losses\n\n    def fetch_first(self, x1, x2, T):\n        assert x1.shape[1:] == x2.shape[1:]\n        c,h,w = x1.shape[1:]\n\n        x1 = x1.view(-1,T+1,c,h,w)\n        x2 = x2.view(-1,T-1,c,h,w)\n\n        x2 = torch.cat((x1[:,-1:], x2), 1)\n        x1 = x1[:,:-1]\n\n        return x1.flatten(0,1), x2.flatten(0,1)\n\n    def forward(self, frames, frames2=None, mask=None, T=None, affine=None, affine2=None, embd_only=False, norm=True, dbg=False):\n        \"\"\"Extract temporal correspondences\n        Args:\n            frames: [B,T,C,H,W]\n\n        Returns:\n            losses: a dictionary with the embedding loss\n            net_outs: feature embeddings\n        \n        \"\"\"\n\n        # embedding for self-supervised learning\n        key1, res3, res4 = self.fast_net(frames, norm)\n\n        outs, losses = {}, {}\n        if embd_only: # only embedding\n            return res3, res4, key1\n        else:\n            key2, res3_2, res4_2, losses = self._forward_reg(frames2, norm)\n\n            # fetching the first frame from the second view\n            key1, key2 = self.fetch_first(key1, key2, T)\n\n            vals, pseudo, index, dbg_info = self._cluster_grid(key1, key2, affine, affine2, T)\n\n            vals_pseudo, pseudo2, grid_mask = dbg_info\n\n            key1_aligned = self._align(key1, affine)\n            key2_aligned = self._align(key2, affine2)\n\n            n_ref = self.cfg.MODEL.GRID_SIZE_REF\n            losses[\"cross_key\"] = self._ref_loss(key1_aligned[::T], key2_aligned[::T], N = n_ref)\n\n            # losses\n            _, losses[\"temp\"], outs[\"error_map\"] = self._ce_loss(vals, pseudo, T)\n\n            # computing the main loss\n            losses[\"main\"] = self.cfg.MODEL.CE_REF * losses[\"cross_key\"] + losses[\"temp\"]\n\n            if dbg:\n                vals = F.softmax(vals, 1)\n                vals_pseudo = 
F.softmax(vals_pseudo, 1)\n\n                frames, frames2 = self.fetch_first(frames, frames2, T)\n                outs[\"frames_orig\"] = frames\n                outs[\"frames\"] = self._align(frames, affine)\n                outs[\"frames2\"] = self._align(frames2, affine2)\n\n                outs[\"map_soft\"] = vals\n                outs[\"map\"] = pseudo\n                outs[\"map_target_soft\"] = vals_pseudo\n                outs[\"map_target\"] = pseudo2\n                outs[\"grid_mask\"] = grid_mask()\n\n                outs[\"aff_mask1\"], outs[\"aff11\"], outs[\"aff12\"], \\\n                        outs[\"aff_mask2\"], outs[\"aff21\"], outs[\"aff22\"] = self._aff_sample(key1, key2, T)\n\n        return losses, outs\n"
  },
  {
    "path": "models/net.py",
    "content": "\"\"\"\nCopyright (c) 2021 TU Darmstadt\nAuthor: Nikita Araslanov <nikita.araslanov@tu-darmstadt.de>\nLicense: Apache License 2.0\n\"\"\"\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom models.base import BaseNet\n\nclass MLP(nn.Sequential):\n\n    def __init__(self, n_in, n_out):\n        super().__init__()\n\n        self.add_module(\"conv1\", nn.Conv2d(n_in, n_in, 1, 1))\n        self.add_module(\"bn1\", nn.BatchNorm2d(n_in))\n        self.add_module(\"relu\", nn.ReLU(True))\n        self.add_module(\"conv2\", nn.Conv2d(n_in, n_out, 1, 1))\n\nclass Net(BaseNet):\n\n    def __init__(self, cfg, backbone):\n        super(Net, self).__init__()\n\n        self.cfg = cfg\n        self.backbone = backbone\n        self.emb_q = MLP(backbone.fdim, cfg.MODEL.FEATURE_DIM)\n\n    def lr_mult(self):\n        \"\"\"Learning rate multiplier for weights.\n        Returns: [old, new]\"\"\"\n        return 1., 1.\n\n    def lr_mult_bias(self):\n        \"\"\"Learning rate multiplier for bias.\n        Returns: [old, new]\"\"\"\n        return 2., 2.\n\n    def forward(self, frames, norm=True):\n        \"\"\"Forward pass to extract projection and task features\"\"\"\n\n        # extracting the time dimension\n        res4, res3 = self.backbone(frames)\n\n        # B,K,H,W\n        query = self.emb_q(res4)\n\n        if norm:\n            query = F.normalize(query, p=2, dim=1)\n            res3 = F.normalize(res3, p=2, dim=1)\n            res4 = F.normalize(res4, p=2, dim=1)\n\n        return query, res3, res4\n"
  },
  {
    "path": "models/resnet18.py",
    "content": "\"\"\"\nBased on Jabri et al., (2020)\nCredit: https://github.com/ajabri/videowalk.git\nLicense: MIT\n\"\"\"\n\nimport os\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nimport torchvision.models.resnet as torch_resnet\nfrom torchvision.models.resnet import BasicBlock\n\nclass ResNet(torch_resnet.ResNet):\n\n    def __init__(self, *args, **kwargs):\n        super(ResNet, self).__init__(*args, **kwargs)\n\n    def filter_layers(self, x):\n        return [l for l in x if getattr(self, l) is not None]\n\n    def remove_layers(self, remove_layers=[]):\n        # Remove extraneous layers\n        remove_layers += ['fc', 'avgpool']\n        for layer in self.filter_layers(remove_layers):\n            setattr(self, layer, None)\n\n    def modify(self):\n\n        # Set stride of layer3 and layer 4 to 1 (from 2)\n        for layer in self.filter_layers(['layer3']):\n            for m in getattr(self, layer).modules():\n                if isinstance(m, torch.nn.Conv2d):\n                    m.stride = (1, 1)\n\n        for layer in self.filter_layers(['layer4']):\n            for m in getattr(self, layer).modules():\n                if isinstance(m, torch.nn.Conv2d):\n                    m.stride = (1, 1)\n\n\n    def forward(self, x):\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = self.relu(x)\n        x = x if self.maxpool is None else self.maxpool(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x3 = self.layer3(x) \n        x4 = self.layer4(x3)\n\n        return x4, x3\n\ndef _resnet(arch, block, layers, pretrained, **kwargs):\n    model = ResNet(block, layers, **kwargs)\n    return model\n\ndef resnet18(pretrained='', remove_layers=[], train=True, **kwargs):\n\n    model = _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, **kwargs)\n    model.modify()\n\n    model.remove_layers(remove_layers)\n    setattr(model, \"fdim\", 512)\n    return model\n"
  },
  {
    "path": "opts.py",
    "content": "\"\"\"\nCopyright (c) 2021 TU Darmstadt\nAuthor: Nikita Araslanov <nikita.araslanov@tu-darmstadt.de>\nLicense: Apache License 2.0\n\"\"\"\n\nfrom __future__ import print_function\n\nimport os\nimport torch\nimport argparse\nfrom core.config import cfg\n\ndef add_global_arguments(parser):\n\n    #\n    # Model details\n    #\n    parser.add_argument(\"--snapshot-dir\", type=str, default='./snapshots',\n                        help=\"Where to save snapshots of the model.\")\n    parser.add_argument(\"--logdir\", type=str, default='./logs',\n                        help=\"Where to save log files of the model.\")\n    parser.add_argument(\"--exp\", type=str, default=\"main\",\n                        help=\"ID of the experiment (multiple runs)\")\n    parser.add_argument(\"--run\", type=str, help=\"ID of the run\")\n    parser.add_argument('--workers', type=int, default=8,\n                        metavar='N', help='dataloader threads')\n    parser.add_argument('--seed', default=64, type=int, help='seed for initializing training. ')\n\n    # \n    # Inference only\n    #\n    parser.add_argument(\"--infer-list\", default=\"voc12/val.txt\", type=str)\n    parser.add_argument('--mask-output-dir', type=str, default=None, help='path where to save masks')\n    parser.add_argument(\"--resume\", type=str, default=None, help=\"Snapshot \\\"ID,iter\\\" to load\")\n\n    #\n    # Configuration\n    #\n    parser.add_argument(\n        '--cfg', dest='cfg_file', required=True,\n        help='Config file for training (and optionally testing)')\n    parser.add_argument(\n        '--set', dest='set_cfgs',\n        help='Set config keys. Key value sequence seperate by whitespace.'\n             'e.g. [key] [value] [key] [value]',\n        default=[], nargs='+')\n\ndef maybe_create_dir(path):\n    if not os.path.exists(path):\n        os.makedirs(path)\n\ndef check_global_arguments(args):\n\n    args.cuda = torch.cuda.is_available()\n    print(\"Available threads: \", torch.get_num_threads())\n\n    args.logdir = os.path.join(args.logdir, args.exp, args.run)\n    maybe_create_dir(args.logdir)\n\n    #\n    # Model directories\n    #\n    args.snapshot_dir = os.path.join(args.snapshot_dir, args.exp, args.run)\n    maybe_create_dir(args.snapshot_dir)\n\ndef get_arguments(args_in):\n    \"\"\"Parse all the arguments provided from the CLI.\n    \n    Returns:\n      A list of parsed arguments.\n    \"\"\"\n    parser = argparse.ArgumentParser(description=\"Dense Unsupervised Learning for Video Segmentation\")\n\n    add_global_arguments(parser)\n    args = parser.parse_args(args_in)\n    check_global_arguments(args)\n\n    return args\n"
  },
  {
    "path": "requirements.txt",
    "content": "# This file may be used to create an environment using:\n# $ conda create --name <env> --file <this file>\n# platform: linux-64\nsetproctitle\nmatplotlib\ntensorboard\npyyaml\npackaging\nopencv-python\nscikit-image\n"
  },
  {
    "path": "train.py",
    "content": "\"\"\"\nCopyright (c) 2021 TU Darmstadt\nAuthor: Nikita Araslanov <nikita.araslanov@tu-darmstadt.de>\nLicense: Apache License 2.0\n\"\"\"\n\nfrom __future__ import print_function\n\nimport os\nimport sys\nimport numpy as np\nimport time\nimport random\nimport setproctitle\n\nfrom functools import partial\n\nimport torch\nimport torch.nn.functional as F\n\nfrom datasets import *\nfrom models import get_model\n\nfrom base_trainer import BaseTrainer\n\nfrom opts import get_arguments\nfrom core.config import cfg, cfg_from_file, cfg_from_list\nfrom utils.timer import Timer\nfrom utils.stat_manager import StatManager\nfrom utils.davis2017 import evaluate_semi\nfrom labelprop.crw import CRW\n\nfrom torch.utils.tensorboard import SummaryWriter\n\ntorch.backends.cudnn.benchmark = True\ntorch.backends.cudnn.deterministic = False\n\nclass Trainer(BaseTrainer):\n\n    def __init__(self, args, cfg):\n\n        super(Trainer, self).__init__(args, cfg)\n\n        # train loader for target domain\n        self.loader = get_dataloader(args, cfg, 'train')\n\n        # alias\n        self.denorm = self.loader.dataset.denorm \n\n        # val loaders for source and target domains\n        self.valloaders = get_dataloader(args, cfg, 'val')\n\n        # writers (only main)\n        self.writer_val = {}\n        for val_set in self.valloaders.keys():\n            logdir_val = os.path.join(args.logdir, val_set)\n            self.writer_val[val_set] = SummaryWriter(logdir_val)\n\n        # model\n        self.net = get_model(cfg, remove_layers=cfg.MODEL.REMOVE_LAYERS)\n\n        print(\"Train Net: \")\n        print(self.net)\n\n        # optimizer using different LR\n        net_params = self.net.parameter_groups(cfg.MODEL.LR, cfg.MODEL.WEIGHT_DECAY)\n\n        print(\"Optimising parameter groups: \")\n        for i, g in enumerate(net_params):\n            print(\"[{}]: # parameters: {}, lr = {:4.3e}\".format(i, len(g[\"params\"]), g[\"lr\"]))\n\n        self.optim = self.get_optim(net_params, cfg.MODEL)\n\n        print(\"# of params: \", len(list(self.net.parameters())))\n\n        # LR scheduler\n        if cfg.MODEL.LR_SCHEDULER == \"step\":\n            self.scheduler = torch.optim.lr_scheduler.StepLR(self.optim, \\\n                                                             step_size=cfg.MODEL.LR_STEP, \\\n                                                             gamma=cfg.MODEL.LR_GAMMA)\n        elif cfg.MODEL.LR_SCHEDULER == \"linear\": # linear decay\n\n            def lr_lambda(epoch):\n                mult = 1 - epoch / (float(self.cfg.TRAIN.NUM_EPOCHS) - self.start_epoch)\n                mult = mult ** self.cfg.MODEL.LR_POWER\n                #print(\"Linear Scheduler: mult = {:4.3f}\".format(mult))\n                return mult\n\n            self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optim, lr_lambda)\n        else:\n            self.scheduler = None\n\n        self.vis_batch = None\n\n        # using cuda\n        self.net.cuda()\n        self.crw = CRW(cfg.TEST)\n\n        # checkpoint management\n        self.checkpoint.create_model(self.net, self.optim)\n        if not args.resume is None:\n            self.start_epoch, self.best_score = self.checkpoint.load(args.resume, \"cuda:0\")\n\n\n    def step_seg(self, epoch, batch_src, key, temp=None, train=False, visualise=False, \\\n                 save_batch=False, writer=None, tag=\"train_src\"):\n\n        frames, masks_gt, n_obj, seq_name = batch_src\n\n        # semi-supervised: select only the first\n        
frames = frames.flatten(0,1)\n        masks_gt = masks_gt.flatten(0,1)\n        masks_gt = masks_gt[:, :n_obj.item()]\n\n        masks_ref = masks_gt.clone()\n        masks_ref[1:] *= 0\n\n        T = frames.shape[0]\n\n        fetch = {\"res3\": lambda x: x[0], \\\n                 \"res4\": lambda x: x[1], \\\n                 \"key\": lambda x: x[2]}\n\n        # chunked inference to bound the memory use\n        bs = self.cfg.TRAIN.BATCH_SIZE\n        feats = []\n        t0 = time.time()\n\n        torch.cuda.empty_cache()\n\n        for t in range(0, T, bs):\n\n            # next chunk of frames\n            frames_batch = frames[t:t+bs].cuda()\n\n            # source forward pass\n            feats_ = self.net(frames_batch, embd_only=True)\n            feats.append(fetch[key](feats_).cpu())\n\n        feats = torch.cat(feats, 0)\n        print(\"Inference: {:4.3f}s\".format(time.time() - t0))\n        sys.stdout.flush()\n        t0 = time.time()\n        outs = self.crw.forward(feats, masks_ref)\n        print(\"CRW propagation: {:4.3f}s\".format(time.time() - t0))\n        sys.stdout.flush()\n        outs[\"masks_gt\"] = masks_gt.argmax(1)\n\n        if visualise:\n            outs[\"frames\"] = frames\n            self._visualise_seg(epoch, outs, writer, tag)\n\n        if save_batch:\n            # saving the batch for visualisation\n            self.save_vis_batch(tag, batch_src)\n\n        return outs\n\n    def step(self, epoch, batch_in, train=False, visualise=False, save_batch=False, writer=None, tag=\"train\"):\n\n        frames1, frames2, affine1, affine2 = batch_in\n        assert frames1.size() == frames2.size(), \"Frames shape mismatch\"\n\n        # We could simply do\n        #   images1 = frames1.flatten(0,1).cuda()\n        #   images2 = frames2.flatten(0,1).cuda()\n        # Instead, we pull the reference frame from the 2nd view\n        # into the first view, so that the regularising branch\n        # always runs in evaluation mode, which saves GPU memory\n\n        B,T,C,H,W = frames1.shape\n        images1 = torch.cat((frames1, frames2[:, ::T]), 1)\n        images1 = images1.flatten(0,1).cuda()\n        images2 = frames2[:, 1:].flatten(0,1).cuda()\n\n        affine1 = affine1.flatten(0,1).cuda()\n        affine2 = affine2.flatten(0,1).cuda()\n\n        # source forward pass\n        losses, outs = self.net(images1, frames2=images2, T=T, \\\n                                affine=affine1, affine2=affine2, \\\n                                dbg=visualise)\n\n        if train:\n            self.optim.zero_grad()\n            losses[\"main\"].backward()\n            self.optim.step()\n\n        if visualise:\n            self._visualise(epoch, outs, T, writer, tag)\n\n        if save_batch:\n            # saving the batch for visualisation\n            self.save_vis_batch(tag, batch_in)\n\n        # summarising the losses\n        # into python scalars\n        losses_ret = {}\n        for key, val in losses.items():\n            losses_ret[key] = val.mean().item()\n\n        return losses_ret, outs\n\n    def train_epoch(self, epoch):\n\n        stat = StatManager()\n\n        timer = Timer(\"Epoch {}\".format(epoch))\n        step = partial(self.step, train=True, visualise=False)\n\n        # training mode\n        self.net.train()\n\n        for i, batch in enumerate(self.loader):\n\n            save_batch = (i == 0)\n\n            #\n            # Forward pass\n            #\n            losses, _ = step(epoch, batch, 
save_batch=save_batch, tag=\"train\")\n\n            for loss_key, loss_val in losses.items():\n                stat.update_stats(loss_key, loss_val)\n\n            # intermediate logging\n            if i % 10 == 0:\n                msg =  \"Loss [{:04d}]: \".format(i)\n                for loss_key, loss_val in losses.items():\n                    msg += \" {} {:.4f} | \".format(loss_key, loss_val)\n                msg += \" | Im/Sec: {:.1f}\".format(i * self.cfg.TRAIN.BATCH_SIZE / timer.get_stage_elapsed())\n                print(msg)\n                sys.stdout.flush()\n        \n        for name, val in stat.items():\n            print(\"{}: {:4.3f}\".format(name, val))\n            self.writer.add_scalar('all/{}'.format(name), val, epoch)\n\n        # plotting learning rate\n        for ii, l in enumerate(self.optim.param_groups):\n            print(\"Learning rate [{}]: {:4.3e}\".format(ii, l['lr']))\n            self.writer.add_scalar('lr/enc_group_%02d' % ii, l['lr'], epoch)\n\n        # plotting moment distance\n        if stat.has_vals(\"lr_gamma\"):\n            self.writer.add_scalar('hyper/gamma', stat.summarize_key(\"lr_gamma\"), epoch)\n\n        if epoch % self.cfg.LOG.ITER_TRAIN == 0:\n            self.visualise_results(epoch, self.writer, \"train\", self.step)\n\n    def validation(self, epoch, writer, loader, tag=None, max_iter=None):\n\n        stat = StatManager()\n\n        if max_iter is None:\n            max_iter = len(loader)\n\n        # Fast test during the training\n        def eval_batch(batch):\n\n            loss, masks = self.step(epoch, batch, train=False, visualise=False)\n\n            for loss_key, loss_val in loss.items():\n                stat.update_stats(loss_key, loss_val)\n\n            return masks\n\n        self.net.eval()\n\n        print(\"Starting validation\")\n        sys.stdout.flush()\n        for n, batch in enumerate(loader):\n\n            with torch.no_grad():\n                # note video length assumed 1\n                eval_batch(batch)\n\n            if not tag is None and not self.has_vis_batch(tag):\n                self.save_vis_batch(tag, batch)\n\n        checkpoint_score = 0.0\n\n        # total classification loss\n        for stat_key, stat_val in stat.items():\n            writer.add_scalar('all/{}'.format(stat_key), stat_val, epoch)\n            print('Loss {}: {:4.3f}'.format(stat_key, stat_val))\n\n        if not tag is None and epoch % self.cfg.LOG.ITER_TRAIN == 0:\n            self.visualise_results(epoch, writer, tag, self.step)\n\n        return checkpoint_score\n\n    def validation_seg(self, epoch, writer, loader, key=\"all\", temp=None, tag=None, max_iter=None):\n\n        vis = key == \"res4\"\n        stat = StatManager()\n\n        if max_iter is None:\n            max_iter = len(loader)\n\n        if temp is None:\n            temp = self.cfg.TEST.TEMP\n\n        step_fn = partial(self.step_seg, key=key, temp=temp, train=False, visualise=vis, writer=writer)\n\n        # Fast test during the training\n        def eval_batch(n, batch):\n            tag_n = tag + \"_{:02d}\".format(n)\n            masks = step_fn(epoch, batch, tag=tag_n)\n            return masks\n\n        self.net.eval()\n\n        def davis_mask(masks):\n            masks = masks.cpu() \n            num_objects = int(masks.max())\n            tmp = torch.ones(num_objects, *masks.shape)\n            tmp = tmp * torch.arange(1, num_objects + 1)[:, None, None, None]\n            return (tmp == masks[None, ...]).long().numpy()\n\n        Js = 
{\"M\": [], \"R\": [], \"D\": []}\n        Fs = {\"M\": [], \"R\": [], \"D\": []}\n\n        timer = Timer(\"[Epoch {}] Validation-Seg\".format(epoch))\n        tag_key = \"{}_{}_{:3.2f}\".format(tag, key, temp)\n        for n, batch in enumerate(loader):\n            seq_name = batch[-1][0]\n            print(\"Sequence: \", seq_name)\n            sys.stdout.flush()\n\n            with torch.no_grad():\n                masks_out = eval_batch(n, batch)\n\n            # second element is assumed to be always GT masks\n            masks_gt = davis_mask(masks_out[\"masks_gt\"])\n            masks_pred = davis_mask(masks_out[\"masks_pred_idx\"])\n            assert masks_gt.shape == masks_pred.shape\n\n            # converting to a digestible format\n            # [num_objects, seq_length, height, width]\n\n            if not tag_key is None and not self.has_vis_batch(tag_key):\n                self.save_vis_batch(tag_key, batch)\n\n            start_t = time.time()\n            metrics_res = evaluate_semi((masks_gt, ), (masks_pred, ))\n            J, F = metrics_res['J'], metrics_res['F']\n\n            print(\"Evaluation: {:4.3f}s\".format(time.time() - start_t))\n            print(\"Jaccard: \", J[\"M\"])\n            print(\"F-Score: \", F[\"M\"])\n\n            for l in (\"M\", \"R\", \"D\"):\n                Js[l] += J[l]\n                Fs[l] += F[l]\n\n            msg = \"{} | Im/Sec: {:.1f}\".format(n, n * batch[0].shape[1] / timer.get_stage_elapsed())\n            print(msg)\n            sys.stdout.flush()\n\n        g_measures = ['J&F-Mean', 'J-Mean', 'J-Recall', 'J-Decay', 'F-Mean', 'F-Recall', 'F-Decay']\n\n        # Generate dataframe for the general results\n        final_mean = (np.mean(Js[\"M\"]) + np.mean(Fs[\"M\"])) / 2.\n        g_res = [final_mean, \\\n                 np.mean(Js[\"M\"]), np.mean(Js[\"R\"]), np.mean(Js[\"D\"]), \\\n                 np.mean(Fs[\"M\"]), np.mean(Fs[\"R\"]), np.mean(Fs[\"D\"])]\n\n        for (name, val) in zip(g_measures, g_res):\n            writer.add_scalar('{}_{:3.2f}/{}'.format(key, temp, name), val, epoch)\n            print('{}: {:4.3f}'.format(name, val))\n\n        return final_mean\n\n\ndef train(args, cfg):\n\n    setproctitle.setproctitle(\"dense-ulearn | {}\".format(args.run))\n\n    if args.seed is not None:\n        print(\"Setting the seed: {}\".format(args.seed))\n        random.seed(args.seed)\n        np.random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        torch.cuda.manual_seed(args.seed)\n        torch.cuda.manual_seed_all(args.seed)\n\n    trainer = Trainer(args, cfg)\n\n    timer = Timer()\n    def time_call(func, msg, *args, **kwargs):\n        timer.reset_stage()\n        val = func(*args, **kwargs)\n        print(msg + (\" {:3.2}m\".format(timer.get_stage_elapsed() / 60.)))\n        return val\n\n    for epoch in range(trainer.start_epoch, cfg.TRAIN.NUM_EPOCHS + 1):\n\n        # training 1 epoch\n        time_call(trainer.train_epoch, \"Train epoch: \", epoch)\n\n        print(\"Epoch >>> {:02d} <<< \".format(epoch))\n        if epoch % cfg.LOG.ITER_VAL == 0:\n\n            for val_set in (\"val_video\", ):\n                time_call(trainer.validation, \"Validation / {} /  Val: \".format(val_set), \\\n                          epoch, trainer.writer_val[val_set], trainer.valloaders[val_set], tag=val_set)\n\n            best_layer = None\n            best_score = -1e10\n            for val_set in (\"val_video_seg\", ):\n                writer = trainer.writer_val[val_set]\n                loader = 
trainer.valloaders[val_set]\n                for layer in (\"key\", \"res4\"):\n                    msg = \">>> Validation {} / {} <<<\".format(layer, val_set)\n                    score = time_call(trainer.validation_seg, msg, epoch, writer, loader, key=layer, tag=val_set)\n                    if score > best_score:\n                        best_score = score\n                        best_layer = layer\n\n                print(\"Best score / layer: {:4.2f} / {}\".format(best_score, best_layer))\n                if val_set == \"val_video_seg\":\n                    trainer.checkpoint_best(best_score, epoch, best_layer)\n\n        if trainer.scheduler is not None and cfg.MODEL.LR_SCHED_USE_EPOCH:\n            trainer.scheduler.step()\n\ndef main():\n    args = get_arguments(sys.argv[1:])\n\n    # Reading the config\n    cfg_from_file(args.cfg_file)\n    if args.set_cfgs is not None:\n        cfg_from_list(args.set_cfgs)\n\n    train(args, cfg)\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "utils/checkpoints.py",
    "content": "\"\"\"\nCopyright (c) 2021 TU Darmstadt\nAuthor: Nikita Araslanov <nikita.araslanov@tu-darmstadt.de>\nLicense: Apache License 2.0\n\"\"\"\n\nimport os\nimport sys\nimport torch\n\n\nclass Checkpoint(object):\n\n    def __init__(self, path, max_n=3):\n        self.path = path\n        self.max_n = max_n\n        self.models = {}\n        self.checkpoints = []\n\n    def create_model(self, model, opt):\n        self.models = {}\n        self.models['model'] = model\n        self.models['opt'] = opt\n\n    def limit(self):\n        return self.max_n\n\n    def __len__(self):\n        return len(self.checkpoints)\n\n    def _get_full_path(self, suffix):\n        filename = self._filename(suffix)\n        return os.path.join(self.path, filename)\n\n    def clean(self):\n        n_remove = max(0, len(self.checkpoints) - self.max_n)\n        for i in range(n_remove):\n            self._rm(self.checkpoints[i])\n        self.checkpoints = self.checkpoints[n_remove:]\n\n    def _rm(self, suffix):\n        path = self._get_full_path(suffix)\n        if os.path.isfile(path):\n            os.remove(path)\n\n    def _filename(self, suffix):\n        return \"{}.pth\".format(suffix)\n\n    def load(self, path, location):\n        if len(path) > 0 and not os.path.isfile(path):\n            print(\"Snapshot {} not found\".format(path))\n\n        data = torch.load(path, map_location=location)\n        data_mapped = {}\n        for key, val in data[\"model\"].items():\n            data_mapped[key.replace(\"module.\", \"\")] = val\n\n        self.models[\"model\"].load_state_dict(data_mapped, strict=True)\n        return data[\"epoch\"], data[\"score\"]\n\n    def checkpoint(self, score, epoch, t):\n        suffix = \"epoch{:03d}_score{:4.3f}_{}\".format(epoch, score, t)\n        self.checkpoints.append(suffix)\n\n        path = self._get_full_path(suffix)\n        if not os.path.isfile(path):\n            torch.save({\"model\": self.models[\"model\"].state_dict(),\n                        \"opt\": self.models[\"opt\"].state_dict(),\n                        \"score\": score,\n                        \"epoch\": epoch}, path)\n\n        # removing if more than allowed number of snapshots\n        self.clean()\n"
  },
  {
    "path": "utils/collections.py",
    "content": "# Copyright (c) 2017-present, Facebook, Inc.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n##############################################################################\n\n\"\"\"A simple attribute dictionary used for representing configuration options.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\nfrom __future__ import unicode_literals\n\n\nclass AttrDict(dict):\n\n    IMMUTABLE = '__immutable__'\n\n    def __init__(self, *args, **kwargs):\n        super(AttrDict, self).__init__(*args, **kwargs)\n        self.__dict__[AttrDict.IMMUTABLE] = False\n\n    def __getattr__(self, name):\n        if name in self.__dict__:\n            return self.__dict__[name]\n        elif name in self:\n            return self[name]\n        else:\n            raise AttributeError(name)\n\n    def __setattr__(self, name, value):\n        if not self.__dict__[AttrDict.IMMUTABLE]:\n            if name in self.__dict__:\n                self.__dict__[name] = value\n            else:\n                self[name] = value\n        else:\n            raise AttributeError(\n                'Attempted to set \"{}\" to \"{}\", but AttrDict is immutable'.\n                format(name, value)\n            )\n\n    def immutable(self, is_immutable):\n        \"\"\"Set immutability to is_immutable and recursively apply the setting\n        to all nested AttrDicts.\n        \"\"\"\n        self.__dict__[AttrDict.IMMUTABLE] = is_immutable\n        # Recursively set immutable state\n        for v in self.__dict__.values():\n            if isinstance(v, AttrDict):\n                v.immutable(is_immutable)\n        for v in self.values():\n            if isinstance(v, AttrDict):\n                v.immutable(is_immutable)\n\n    def is_immutable(self):\n        return self.__dict__[AttrDict.IMMUTABLE]\n"
  },
  {
    "path": "utils/davis2017.py",
    "content": "\"\"\"\nCredit: https://github.com/davisvideochallenge/davis2017-evaluation.git\nLicense: BSD 3-Clause\nCopyright (c) 2020, DAVIS: Densely Annotated VIdeo Segmentation\n\"\"\"\n\nimport sys\nimport numpy as np\nfrom utils.davis2017_metrics import db_eval_boundary, db_eval_iou\nimport utils.davis2017_utils as utils\n\ndef _evaluate_semisupervised(all_gt_masks, all_res_masks, all_void_masks, metric):\n    if all_res_masks.shape[0] > all_gt_masks.shape[0]:\n        sys.stdout.write(\"\\nIn your PNG files there is an index higher than the number of objects in the sequence!\")\n        sys.exit()\n    elif all_res_masks.shape[0] < all_gt_masks.shape[0]:\n        sys.stdout.write(\"\\nThe number of predictions is less than ground truth. Padding with zero.\")\n        zero_padding = np.zeros((all_gt_masks.shape[0] - all_res_masks.shape[0], *all_res_masks.shape[1:]))\n        all_res_masks = np.concatenate([all_res_masks, zero_padding], axis=0)\n    j_metrics_res, f_metrics_res = np.zeros(all_gt_masks.shape[:2]), np.zeros(all_gt_masks.shape[:2])\n    for ii in range(all_gt_masks.shape[0]):\n\n        if 'J' in metric:\n            j_metrics_res[ii, :] = db_eval_iou(all_gt_masks[ii, ...], all_res_masks[ii, ...], all_void_masks)\n            sys.stdout.flush()\n\n        if 'F' in metric:\n            f_metrics_res[ii, :] = db_eval_boundary(all_gt_masks[ii, ...], all_res_masks[ii, ...], all_void_masks)\n            sys.stdout.flush()\n\n    return j_metrics_res, f_metrics_res\n\n\ndef evaluate_semi(all_gt_masks, all_res_masks, metric=('J', 'F'), debug=False):\n    metric = metric if isinstance(metric, tuple) or isinstance(metric, list) else [metric]\n    if 'T' in metric:\n        raise ValueError('Temporal metric not supported!')\n    if 'J' not in metric and 'F' not in metric:\n        raise ValueError('Metric possible values are J for IoU or F for Boundary')\n\n    # Containers\n    metrics_res = {}\n    if 'J' in metric:\n        metrics_res['J'] = {\"M\": [], \"R\": [], \"D\": [], \"M_per_object\": {}}\n    if 'F' in metric:\n        metrics_res['F'] = {\"M\": [], \"R\": [], \"D\": [], \"M_per_object\": {}}\n\n    for seq, (seq_gt_masks, seq_res_masks) in enumerate(zip(all_gt_masks, all_res_masks)):\n\n        seq_gt_masks, seq_res_masks = seq_gt_masks[:, 1:-1, :, :], seq_res_masks[:, 1:-1, :, :]\n        j_metrics_res, f_metrics_res = _evaluate_semisupervised(seq_gt_masks, seq_res_masks, None, metric)\n\n        for ii in range(seq_gt_masks.shape[0]):\n            seq_name = f'{seq}_{ii+1}'\n            if 'J' in metric:\n                [JM, JR, JD] = utils.db_statistics(j_metrics_res[ii])\n                metrics_res['J'][\"M\"].append(JM)\n                metrics_res['J'][\"R\"].append(JR)\n                metrics_res['J'][\"D\"].append(JD)\n                metrics_res['J'][\"M_per_object\"][seq_name] = JM\n            if 'F' in metric:\n                [FM, FR, FD] = utils.db_statistics(f_metrics_res[ii])\n                metrics_res['F'][\"M\"].append(FM)\n                metrics_res['F'][\"R\"].append(FR)\n                metrics_res['F'][\"D\"].append(FD)\n                metrics_res['F'][\"M_per_object\"][seq_name] = FM\n\n        # Show progress\n        if debug:\n            sys.stdout.write(seq + '\\n')\n            sys.stdout.flush()\n\n    return metrics_res\n"
  },
  {
    "path": "utils/davis2017_metrics.py",
    "content": "\"\"\"\nCredit: https://github.com/davisvideochallenge/davis2017-evaluation.git\nLicense: BSD 3-Clause\nCopyright (c) 2020, DAVIS: Densely Annotated VIdeo Segmentation\n\"\"\"\n\nimport math\nimport sys\nimport numpy as np\nimport cv2\n\ncv2.setNumThreads(0)\n\nfrom skimage.morphology import disk\n\n\ndef db_eval_iou(annotation, segmentation, void_pixels=None):\n    \"\"\" Compute region similarity as the Jaccard Index.\n    Arguments:\n        annotation   (ndarray): binary annotation   map.\n        segmentation (ndarray): binary segmentation map.\n        void_pixels  (ndarray): optional mask with void pixels\n    Return:\n        jaccard (float): region similarity\n    \"\"\"\n    assert annotation.shape == segmentation.shape, \\\n        f'Annotation({annotation.shape}) and segmentation:{segmentation.shape} dimensions do not match.'\n    annotation = annotation.astype(np.bool)\n    segmentation = segmentation.astype(np.bool)\n\n    if void_pixels is not None:\n        assert annotation.shape == void_pixels.shape, \\\n            f'Annotation({annotation.shape}) and void pixels:{void_pixels.shape} dimensions do not match.'\n        void_pixels = void_pixels.astype(np.bool)\n    else:\n        void_pixels = np.zeros_like(segmentation)\n\n    # Intersection between all sets\n    inters = np.sum((segmentation & annotation) & np.logical_not(void_pixels), axis=(-2, -1))\n    union = np.sum((segmentation | annotation) & np.logical_not(void_pixels), axis=(-2, -1))\n\n    j = inters / union\n    if j.ndim == 0:\n        j = 1 if np.isclose(union, 0) else j\n    else:\n        j[np.isclose(union, 0)] = 1\n\n    return j\n\n\ndef db_eval_boundary(annotation, segmentation, void_pixels=None, bound_th=0.008):\n    assert annotation.shape == segmentation.shape\n    if void_pixels is not None:\n        assert annotation.shape == void_pixels.shape\n    if annotation.ndim == 3:\n        n_frames = annotation.shape[0]\n        f_res = np.zeros(n_frames)\n        for frame_id in range(n_frames):\n            void_pixels_frame = None if void_pixels is None else void_pixels[frame_id, :, :, ]\n            f_res[frame_id] = f_measure(segmentation[frame_id, :, :, ], annotation[frame_id, :, :], void_pixels_frame, bound_th=bound_th)\n    elif annotation.ndim == 2:\n        f_res = f_measure(segmentation, annotation, void_pixels, bound_th=bound_th)\n    else:\n        raise ValueError(f'db_eval_boundary does not support tensors with {annotation.ndim} dimensions')\n    return f_res\n\n\ndef f_measure(foreground_mask, gt_mask, void_pixels=None, bound_th=0.008):\n    \"\"\"\n    Compute mean,recall and decay from per-frame evaluation.\n    Calculates precision/recall for boundaries between foreground_mask and\n    gt_mask using morphological operators to speed it up.\n    Arguments:\n        foreground_mask (ndarray): binary segmentation image.\n        gt_mask         (ndarray): binary annotated image.\n        void_pixels     (ndarray): optional mask with void pixels\n    Returns:\n        F (float): boundaries F-measure\n    \"\"\"\n    assert np.atleast_3d(foreground_mask).shape[2] == 1\n    if void_pixels is not None:\n        void_pixels = void_pixels.astype(np.bool)\n    else:\n        void_pixels = np.zeros_like(foreground_mask).astype(np.bool)\n\n    bound_pix = bound_th if bound_th >= 1 else \\\n        np.ceil(bound_th * np.linalg.norm(foreground_mask.shape))\n\n    # Get the pixel boundaries of both masks\n    fg_boundary = _seg2bmap(foreground_mask * np.logical_not(void_pixels))\n    
gt_boundary = _seg2bmap(gt_mask * np.logical_not(void_pixels))\n\n    # fg_dil = binary_dilation(fg_boundary, disk(bound_pix))\n    fg_dil = cv2.dilate(fg_boundary.astype(np.uint8), disk(bound_pix).astype(np.uint8))\n    # gt_dil = binary_dilation(gt_boundary, disk(bound_pix))\n    gt_dil = cv2.dilate(gt_boundary.astype(np.uint8), disk(bound_pix).astype(np.uint8))\n\n    # Get the intersection\n    gt_match = gt_boundary * fg_dil\n    fg_match = fg_boundary * gt_dil\n\n    # Area of the intersection\n    n_fg = np.sum(fg_boundary)\n    n_gt = np.sum(gt_boundary)\n\n    # Compute precision and recall\n    if n_fg == 0 and n_gt > 0:\n        precision = 1\n        recall = 0\n    elif n_fg > 0 and n_gt == 0:\n        precision = 0\n        recall = 1\n    elif n_fg == 0 and n_gt == 0:\n        precision = 1\n        recall = 1\n    else:\n        precision = np.sum(fg_match) / float(n_fg)\n        recall = np.sum(gt_match) / float(n_gt)\n\n    # Compute F measure\n    if precision + recall == 0:\n        F = 0\n    else:\n        F = 2 * precision * recall / (precision + recall)\n\n    return F\n\n\ndef _seg2bmap(seg, width=None, height=None):\n    \"\"\"\n    From a segmentation, compute a binary boundary map with 1 pixel wide\n    boundaries.  The boundary pixels are offset by 1/2 pixel towards the\n    origin from the actual segment boundary.\n    Arguments:\n        seg     : Segments labeled from 1..k.\n        width\t  :\tWidth of desired bmap  <= seg.shape[1]\n        height  :\tHeight of desired bmap <= seg.shape[0]\n    Returns:\n        bmap (ndarray):\tBinary boundary map.\n     David Martin <dmartin@eecs.berkeley.edu>\n     January 2003\n    \"\"\"\n\n    seg = seg.astype(bool)\n    seg[seg > 0] = 1\n\n    assert np.atleast_3d(seg).shape[2] == 1\n\n    width = seg.shape[1] if width is None else width\n    height = seg.shape[0] if height is None else height\n\n    h, w = seg.shape[:2]\n\n    ar1 = float(width) / float(height)\n    ar2 = float(w) / float(h)\n\n    assert not (\n        (width > w) or (height > h) or (abs(ar1 - ar2) > 0.01)\n    ), \"Can't convert %dx%d seg to %dx%d bmap.\" % (w, h, width, height)\n\n    e = np.zeros_like(seg)\n    s = np.zeros_like(seg)\n    se = np.zeros_like(seg)\n\n    e[:, :-1] = seg[:, 1:]\n    s[:-1, :] = seg[1:, :]\n    se[:-1, :-1] = seg[1:, 1:]\n\n    b = seg ^ e | seg ^ s | seg ^ se\n    b[-1, :] = seg[-1, :] ^ e[-1, :]\n    b[:, -1] = seg[:, -1] ^ s[:, -1]\n    b[-1, -1] = 0\n\n    if w == width and h == height:\n        bmap = b\n    else:\n        bmap = np.zeros((height, width))\n        for x in range(w):\n            for y in range(h):\n                if b[y, x]:\n                    # map (y, x) into the target resolution (cf. Martin's seg2bmap)\n                    j = 1 + math.floor((y - 1) * height / h)\n                    i = 1 + math.floor((x - 1) * width / w)\n                    bmap[j, i] = 1\n\n    return bmap\n\n\nif __name__ == '__main__':\n    from davis2017.davis import DAVIS\n    from davis2017.results import Results\n\n    dataset = DAVIS(root='input_dir/ref', subset='val', sequences='aerobatics')\n    results = Results(root_dir='examples/osvos')\n    # Test timing F measure\n    for seq in dataset.get_sequences():\n        all_gt_masks, _, all_masks_id = dataset.get_all_masks(seq, True)\n        all_gt_masks, all_masks_id = all_gt_masks[:, 1:-1, :, :], all_masks_id[1:-1]\n        all_res_masks = results.read_masks(seq, all_masks_id)\n        f_metrics_res = np.zeros(all_gt_masks.shape[:2])\n        for ii in range(all_gt_masks.shape[0]):\n            f_metrics_res[ii, :] = db_eval_boundary(all_gt_masks[ii, ...], all_res_masks[ii, ...])\n\n    # To profile this code, run:  python -m cProfile -o f_measure.prof metrics.py\n    # then visualise with:        snakeviz f_measure.prof\n"
  },
  {
    "path": "utils/davis2017_utils.py",
    "content": "\"\"\"\nCredit: https://github.com/davisvideochallenge/davis2017-evaluation.git\nLicense: BSD 3-Clause\nCopyright (c) 2020, DAVIS: Densely Annotated VIdeo Segmentation\n\"\"\"\n\nimport os\nimport errno\nimport numpy as np\nfrom PIL import Image\nimport warnings\n\n\ndef _pascal_color_map(N=256, normalized=False):\n    \"\"\"\n    Python implementation of the color map function for the PASCAL VOC data set.\n    Official Matlab version can be found in the PASCAL VOC devkit\n    http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html#devkit\n    \"\"\"\n\n    def bitget(byteval, idx):\n        return (byteval & (1 << idx)) != 0\n\n    dtype = 'float32' if normalized else 'uint8'\n    cmap = np.zeros((N, 3), dtype=dtype)\n    for i in range(N):\n        r = g = b = 0\n        c = i\n        for j in range(8):\n            r = r | (bitget(c, 0) << 7 - j)\n            g = g | (bitget(c, 1) << 7 - j)\n            b = b | (bitget(c, 2) << 7 - j)\n            c = c >> 3\n\n        cmap[i] = np.array([r, g, b])\n\n    cmap = cmap / 255 if normalized else cmap\n    return cmap\n\n\ndef overlay_semantic_mask(im, ann, alpha=0.5, colors=None, contour_thickness=None):\n    im, ann = np.asarray(im, dtype=np.uint8), np.asarray(ann, dtype=np.int)\n    if im.shape[:-1] != ann.shape:\n        raise ValueError('First two dimensions of `im` and `ann` must match')\n    if im.shape[-1] != 3:\n        raise ValueError('im must have three channels at the 3 dimension')\n\n    colors = colors or _pascal_color_map()\n    colors = np.asarray(colors, dtype=np.uint8)\n\n    mask = colors[ann]\n    fg = im * alpha + (1 - alpha) * mask\n\n    img = im.copy()\n    img[ann > 0] = fg[ann > 0]\n\n    if contour_thickness:  # pragma: no cover\n        import cv2\n        for obj_id in np.unique(ann[ann > 0]):\n            contours = cv2.findContours((ann == obj_id).astype(\n                np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2:]\n            cv2.drawContours(img, contours[0], -1, colors[obj_id].tolist(),\n                             contour_thickness)\n    return img\n\n\ndef generate_obj_proposals(davis_root, subset, num_proposals, save_path):\n    dataset = DAVIS(davis_root, subset=subset, codalab=True)\n    for seq in dataset.get_sequences():\n        save_dir = os.path.join(save_path, seq)\n        if os.path.exists(save_dir):\n            continue\n        all_gt_masks, all_masks_id = dataset.get_all_masks(seq, True)\n        img_size = all_gt_masks.shape[2:]\n        num_rows = int(np.ceil(np.sqrt(num_proposals)))\n        proposals = np.zeros((num_proposals, len(all_masks_id), *img_size))\n        height_slices = np.floor(np.arange(0, img_size[0] + 1, img_size[0]/num_rows)).astype(np.uint).tolist()\n        width_slices = np.floor(np.arange(0, img_size[1] + 1, img_size[1]/num_rows)).astype(np.uint).tolist()\n        ii = 0\n        prev_h, prev_w = 0, 0\n        for h in height_slices[1:]:\n            for w in width_slices[1:]:\n                proposals[ii, :, prev_h:h, prev_w:w] = 1\n                prev_w = w\n                ii += 1\n                if ii == num_proposals:\n                    break\n            prev_h, prev_w = h, 0\n            if ii == num_proposals:\n                break\n\n        os.makedirs(save_dir, exist_ok=True)\n        for i, mask_id in enumerate(all_masks_id):\n            mask = np.sum(proposals[:, i, ...] 
* np.arange(1, proposals.shape[0] + 1)[:, None, None], axis=0)\n            save_mask(mask, os.path.join(save_dir, f'{mask_id}.png'))\n\n\ndef generate_random_permutation_gt_obj_proposals(davis_root, subset, save_path):\n    dataset = DAVIS(davis_root, subset=subset, codalab=True)\n    for seq in dataset.get_sequences():\n        gt_masks, all_masks_id = dataset.get_all_masks(seq, True)\n        obj_swap = np.random.permutation(np.arange(gt_masks.shape[0]))\n        gt_masks = gt_masks[obj_swap, ...]\n        save_dir = os.path.join(save_path, seq)\n        os.makedirs(save_dir, exist_ok=True)\n        for i, mask_id in enumerate(all_masks_id):\n            mask = np.sum(gt_masks[:, i, ...] * np.arange(1, gt_masks.shape[0] + 1)[:, None, None], axis=0)\n            save_mask(mask, os.path.join(save_dir, f'{mask_id}.png'))\n\n\ndef color_map(N=256, normalized=False):\n    def bitget(byteval, idx):\n        return ((byteval & (1 << idx)) != 0)\n\n    dtype = 'float32' if normalized else 'uint8'\n    cmap = np.zeros((N, 3), dtype=dtype)\n    for i in range(N):\n        r = g = b = 0\n        c = i\n        for j in range(8):\n            r = r | (bitget(c, 0) << 7-j)\n            g = g | (bitget(c, 1) << 7-j)\n            b = b | (bitget(c, 2) << 7-j)\n            c = c >> 3\n\n        cmap[i] = np.array([r, g, b])\n\n    cmap = cmap/255 if normalized else cmap\n    return cmap\n\n\ndef save_mask(mask, img_path):\n    if np.max(mask) > 255:\n        raise ValueError('Maximum id pixel value is 255')\n    mask_img = Image.fromarray(mask.astype(np.uint8))\n    mask_img.putpalette(color_map().flatten().tolist())\n    mask_img.save(img_path)\n\n\ndef db_statistics(per_frame_values):\n    \"\"\" Compute mean,recall and decay from per-frame evaluation.\n    Arguments:\n        per_frame_values (ndarray): per-frame evaluation\n    Returns:\n        M,O,D (float,float,float):\n            return evaluation statistics: mean,recall,decay.\n    \"\"\"\n\n    # strip off nan values\n    with warnings.catch_warnings():\n        warnings.simplefilter(\"ignore\", category=RuntimeWarning)\n        M = np.nanmean(per_frame_values)\n        O = np.nanmean(per_frame_values > 0.5)\n\n    N_bins = 4\n    ids = np.round(np.linspace(1, len(per_frame_values), N_bins + 1) + 1e-10) - 1\n    ids = ids.astype(np.uint8)\n\n    D_bins = [per_frame_values[ids[i]:ids[i + 1] + 1] for i in range(0, 4)]\n\n    with warnings.catch_warnings():\n        warnings.simplefilter(\"ignore\", category=RuntimeWarning)\n        D = np.nanmean(D_bins[0]) - np.nanmean(D_bins[3])\n\n    return M, O, D\n\n\ndef list_files(dir, extension=\".png\"):\n    return [os.path.splitext(file_)[0] for file_ in os.listdir(dir) if file_.endswith(extension)]\n\n\ndef force_symlink(file1, file2):\n    try:\n        os.symlink(file1, file2)\n    except OSError as e:\n        if e.errno == errno.EEXIST:\n            os.remove(file2)\n        os.symlink(file1, file2)\n"
  },
  {
    "path": "utils/palette.py",
    "content": "\"\"\"\nCopyright (c) 2021 TU Darmstadt\nAuthor: Nikita Araslanov <nikita.araslanov@tu-darmstadt.de>\nLicense: Apache License 2.0\n\"\"\"\n\nimport matplotlib.cm as cm\nimport numpy as np\nfrom PIL import ImagePalette\n\ndef colormap(N=256):\n    def bitget(byteval, idx):\n        return ((byteval & (1 << idx)) != 0)\n\n    dtype = 'uint8'\n    cmap = []\n    for i in range(N):\n        r = g = b = 0\n        c = i\n        for j in range(8):\n            r = r | (bitget(c, 0) << 7-j)\n            g = g | (bitget(c, 1) << 7-j)\n            b = b | (bitget(c, 2) << 7-j)\n            c = c >> 3\n\n        cmap.append((r, g, b))\n\n    return cmap\n\ndef apply_cmap(masks_pred, cmap):\n    canvas = np.zeros((masks_pred.shape[0], masks_pred.shape[1], 3))\n\n    for label in np.unique(masks_pred):\n        canvas[masks_pred == label, :] = cmap[label]\n\n    return canvas #np.transpose(canvas, [2,0,1])\n\n\ndef create_palette(colormap, num):\n\n    cmap = cm.get_cmap(colormap)\n    palette = ImagePalette.ImagePalette()\n\n    for n in range(num):\n        val = n / num\n        rgb = [int(255*x) for x in cmap(val)[:-1]]\n        palette.getcolor(tuple(rgb))\n\n    return palette\n\ndef custom_palette(nclasses, cname=\"rainbow\"):\n    cmap = cm.get_cmap(cname, nclasses)\n    return cmap\n"
  },
  {
    "path": "utils/palette_davis.py",
    "content": "palette_str = '''0 0 0\n128 0 0\n0 128 0\n128 128 0\n0 0 128\n128 0 128\n0 128 128\n128 128 128\n64 0 0\n191 0 0\n64 128 0\n191 128 0\n64 0 128\n191 0 128\n64 128 128\n191 128 128\n0 64 0\n128 64 0\n0 191 0\n128 191 0\n0 64 128\n128 64 128\n22 22 22\n23 23 23\n24 24 24\n25 25 25\n26 26 26\n27 27 27\n28 28 28\n29 29 29\n30 30 30\n31 31 31\n32 32 32\n33 33 33\n34 34 34\n35 35 35\n36 36 36\n37 37 37\n38 38 38\n39 39 39\n40 40 40\n41 41 41\n42 42 42\n43 43 43\n44 44 44\n45 45 45\n46 46 46\n47 47 47\n48 48 48\n49 49 49\n50 50 50\n51 51 51\n52 52 52\n53 53 53\n54 54 54\n55 55 55\n56 56 56\n57 57 57\n58 58 58\n59 59 59\n60 60 60\n61 61 61\n62 62 62\n63 63 63\n64 64 64\n65 65 65\n66 66 66\n67 67 67\n68 68 68\n69 69 69\n70 70 70\n71 71 71\n72 72 72\n73 73 73\n74 74 74\n75 75 75\n76 76 76\n77 77 77\n78 78 78\n79 79 79\n80 80 80\n81 81 81\n82 82 82\n83 83 83\n84 84 84\n85 85 85\n86 86 86\n87 87 87\n88 88 88\n89 89 89\n90 90 90\n91 91 91\n92 92 92\n93 93 93\n94 94 94\n95 95 95\n96 96 96\n97 97 97\n98 98 98\n99 99 99\n100 100 100\n101 101 101\n102 102 102\n103 103 103\n104 104 104\n105 105 105\n106 106 106\n107 107 107\n108 108 108\n109 109 109\n110 110 110\n111 111 111\n112 112 112\n113 113 113\n114 114 114\n115 115 115\n116 116 116\n117 117 117\n118 118 118\n119 119 119\n120 120 120\n121 121 121\n122 122 122\n123 123 123\n124 124 124\n125 125 125\n126 126 126\n127 127 127\n128 128 128\n129 129 129\n130 130 130\n131 131 131\n132 132 132\n133 133 133\n134 134 134\n135 135 135\n136 136 136\n137 137 137\n138 138 138\n139 139 139\n140 140 140\n141 141 141\n142 142 142\n143 143 143\n144 144 144\n145 145 145\n146 146 146\n147 147 147\n148 148 148\n149 149 149\n150 150 150\n151 151 151\n152 152 152\n153 153 153\n154 154 154\n155 155 155\n156 156 156\n157 157 157\n158 158 158\n159 159 159\n160 160 160\n161 161 161\n162 162 162\n163 163 163\n164 164 164\n165 165 165\n166 166 166\n167 167 167\n168 168 168\n169 169 169\n170 170 170\n171 171 171\n172 172 172\n173 173 173\n174 174 174\n175 175 175\n176 176 176\n177 177 177\n178 178 178\n179 179 179\n180 180 180\n181 181 181\n182 182 182\n183 183 183\n184 184 184\n185 185 185\n186 186 186\n187 187 187\n188 188 188\n189 189 189\n190 190 190\n191 191 191\n192 192 192\n193 193 193\n194 194 194\n195 195 195\n196 196 196\n197 197 197\n198 198 198\n199 199 199\n200 200 200\n201 201 201\n202 202 202\n203 203 203\n204 204 204\n205 205 205\n206 206 206\n207 207 207\n208 208 208\n209 209 209\n210 210 210\n211 211 211\n212 212 212\n213 213 213\n214 214 214\n215 215 215\n216 216 216\n217 217 217\n218 218 218\n219 219 219\n220 220 220\n221 221 221\n222 222 222\n223 223 223\n224 224 224\n225 225 225\n226 226 226\n227 227 227\n228 228 228\n229 229 229\n230 230 230\n231 231 231\n232 232 232\n233 233 233\n234 234 234\n235 235 235\n236 236 236\n237 237 237\n238 238 238\n239 239 239\n240 240 240\n241 241 241\n242 242 242\n243 243 243\n244 244 244\n245 245 245\n246 246 246\n247 247 247\n248 248 248\n249 249 249\n250 250 250\n251 251 251\n252 252 252\n253 253 253\n254 254 254\n255 255 255'''\nimport numpy as np\ntensor = np.array([[float(x)/255 for x in line.split()] for line in palette_str.split('\\n')])\n\nfrom matplotlib.colors import ListedColormap\npalette = ListedColormap(tensor, 'davis')\n"
  },
  {
    "path": "utils/stat_manager.py",
    "content": "\"\"\"\nCopyright (c) 2021 TU Darmstadt\nAuthor: Nikita Araslanov <nikita.araslanov@tu-darmstadt.de>\nLicense: Apache License 2.0\n\"\"\"\n\n\nclass StatManager(object):\n\n    def __init__(self):\n        self.func_keys = {}\n        self.vals = {}\n        self.vals_count = {}\n        self.formats = {}\n\n    def items(self):\n        curr_vals = {}\n        for k in self.vals:\n            if self.has_vals(k):\n                curr_vals[k] = self.summarize_key(k)\n        return curr_vals.items()\n\n    def reset(self):\n        for k in self.vals:\n            self.vals[k] = 0.0\n            self.vals_count[k] = 0.0\n\n    def add_val(self, key, form=\"{:4.3f}\"):\n        self.vals[key] = 0.0\n        self.vals_count[key] = 0.0\n        self.formats[key] = form\n\n    def get_val(self, key):\n        return self.vals[key], self.vals_count[key]\n\n    def add_compute(self, key, func, form=\"{:4.3f}\"):\n        self.func_keys[key] = func\n        self.add_val(key)\n        self.formats[key] = form\n    \n    def update_stats(self, key, val, count = 1):\n        if not key in self.vals:\n            self.add_val(key)\n\n        self.vals[key] += val\n        self.vals_count[key] += count\n\n    def compute_stats(self, a, b, size = 1):\n\n        for k, func in self.func_keys.iteritems():\n            self.vals[k] += func(a, b)\n            self.vals_count[k] += size\n\n    def has_vals(self, k):\n        return k in self.vals_count and \\\n                self.vals_count[k] > 0\n\n    def summarize_key(self, k):\n        if self.has_vals(k) and abs(self.vals[k]) > 0.:\n            return self.vals[k] / self.vals_count[k]\n        else:\n            return 0\n\n    def summarize(self, epoch = 0, verbose = True):\n\n        if verbose:\n            out = \"\\tEpoch[{:03d}]\".format(epoch)\n            for k in self.vals:\n                if self.has_vals(k):\n                    out += (\" / {} \" + self.formats[k]).format(k, self.summarize_key(k))\n            print(out)\n\n"
  },
  {
    "path": "utils/sys_tools.py",
    "content": "\"\"\"\nCopyright (c) 2021 TU Darmstadt\nAuthor: Nikita Araslanov <nikita.araslanov@tu-darmstadt.de>\nLicense: Apache License 2.0\n\"\"\"\n\nimport os\n\ndef check_dir(base_path, name):\n    \"\"\"Make sure the directory exists\"\"\"\n\n    # create the directory\n    fullpath = os.path.join(base_path, name)\n    if not os.path.exists(fullpath):\n        os.makedirs(fullpath)\n\n    return fullpath\n"
  },
  {
    "path": "utils/timer.py",
    "content": "\"\"\"\nCopyright (c) 2021 TU Darmstadt\nAuthor: Nikita Araslanov <nikita.araslanov@tu-darmstadt.de>\nLicense: Apache License 2.0\n\"\"\"\n\nimport time\n\nclass Timer:\n    def __init__(self, starting_msg = None):\n        self.start = time.time()\n        self.stage_start = self.start\n\n        if starting_msg is not None:\n            print(starting_msg, time.ctime(time.time()))\n\n\n    def update_progress(self, progress):\n        self.elapsed = time.time() - self.start\n        self.est_total = self.elapsed / progress\n        self.est_remaining = self.est_total - self.elapsed\n        self.est_finish = int(self.start + self.est_total)\n\n    def stage(self, msg=None):\n        t = self.get_stage_elapsed()\n        self.reset_stage()\n        if not msg is None:\n            print(\"{:4.3f} elapsed: {}\".format(t, msg))\n        return t\n\n    def str_est_finish(self):\n        return str(time.ctime(self.est_finish))\n\n    def get_stage_elapsed(self):\n        return time.time() - self.stage_start\n\n    def reset_stage(self):\n        self.stage_start = time.time()\n"
  }
]