[
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\nwork_dir/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# pyenv\n.python-version\n\n# celery beat schedule file\ncelerybeat-schedule\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n\ndata/\ndata\n.vscode\n.idea\n.DS_Store\n\n# custom\n*.pkl\n*.pkl.json\n*.log.json\n\n# Pytorch\n*.pth\n*.py~\n*.sh~\n\ndebug/*\nvis/\nanalysis/*\npretrain/*\n"
  },
  {
    "path": "DATASET.md",
    "content": "Please prepare the data structure according to the following instructions.\n\nThe final dataset folder should look like this:\n```\nroot \n├── data\n│   ├──  kitti-step\n│   ├──  coco\n│   ├──  VIPSeg\n│   ├──  youtube_vis_2019\n│   ├──  cityscapes\n```\n\n### [VPS] KITTI-STEP\n\nDownload KITTI-STEP from the official website.\n\nThen run scripts/kitti_step_prepare.py to convert it into the format below.\nYou can also download our pre-processed data from https://huggingface.co/LXT/VideoK-Net/tree/main\n\n```\n├── kitti-step\n│   ├──  video_sequence\n│   │   ├── train\n            ├──000018_000331_leftImg8bit.png\n            ├──000018_000331_panoptic.png\n            ├──****\n│   │   ├── val\n│   │   ├── test \n```\n\n\n### [VPS] VIPSeg\n\nDownload the original dataset from the official repo.\\\nFollowing the official repo, we use resized videos for training and evaluation (the shorter side of the input is resized to 720 while keeping the aspect ratio).\n\n```\n├── VIPSeg\n│   ├──  images\n│   │   ├── 1241_qYvEuwrSiXc\n        │      ├──*.jpg\n│   ├──  panomasks \n│   │   ├── 1241_qYvEuwrSiXc\n        │      ├──*.png\n│   ├──  panomasksRGB \n```\n\n\n### [VIS] Youtube-VIS-2019\nWe use pre-processed json files following the mmtracking codebase;\nsee \"tools/dataset/youtubevis2coco.py\".\n\n```\n├── youtube_vis_2019\n│   ├── annotations\n│   │   ├── train.json\n│   │   ├── valid.json\n│   │   ├── youtube_vis_2019_train.json\n│   │   ├── youtube_vis_2019_valid.json\n│   ├── train\n│   │   ├──JPEGImages\n│   │   │   ├──video folders\n│   ├── valid\n│   │   ├──JPEGImages\n│   │   │   ├──video folders\n```\n\n\n### [VSS] VSPW\n\nTo do\n\n\n### [VPS] Cityscapes \n\nThe Video K-Net models for Cityscapes-VPS and Cityscapes-DVPS will not be released due to patent issues and internal usage. \n\nFor these two benchmarks, we suggest referring to our related work: 
ECCV-2022, PolyphonicFormer: A Unified Framework For Panoptic Segmentation + Depth Estimation (winner of ICCV-2021 BMTT workshop)\n(https://github.com/HarborYuan/PolyphonicFormer)\n\n\n## Image Datasets for Pretraining K-Net\n\n### COCO dataset\n\nCOCO is the most commonly used segmentation dataset. It contains 80 thing classes and 53 stuff classes.\n\nThe dataset format is the same as in the original [Detectron2](https://github.com/facebookresearch/detectron2),\nincluding the panoptic segmentation preparation [scripts](https://github.com/facebookresearch/detectron2/blob/master/datasets/prepare_panoptic_fpn.py).\n\nThe final folder should look like this:\n```\n├── coco\n│   ├── annotations\n│   │   ├── panoptic_{train,val}2017.json\n│   │   ├── instances_{train,val}2017.json\n│   ├── train2017\n│   ├── val2017\n│   ├── panoptic_{train,val}2017/  # png annotations\n```\n\n### Cityscapes dataset\n\nCityscapes is a high-resolution road-scene dataset with 19 classes\n(8 thing classes and 11 stuff classes). It provides 2975 images for training, 500 images for validation and 1525 images for testing.\n\nPreparing the Cityscapes dataset takes three steps:\n\n1. Convert the segmentation id maps (original label id maps) to trainId maps (ids 0-18 for training) using\ncsCreateTrainIdLabelImgs.py from the official [cityscapesScripts](https://github.com/mcordts/cityscapesScripts) repo.\n\n2. Run python dataset/prepare_cityscapes.py to generate the COCO-like annotations, which can be used for instance segmentation training,\nand put instancesonly_filtered_gtFine_train.json into the annotations folder.\n\n3. For the panoptic segmentation dataset, generate the json file using csCreatePanopticImgs.py,\nor download our converted .json and .png files via link: () and put the\njson file into the annotations folder. 
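
Step 1 above is essentially a per-pixel lookup. Each original Cityscapes label id is mapped to one of the 19 training ids (0-18), and every other id is mapped to the ignore value 255. The sketch below only illustrates that remapping in plain Python. The id table follows the standard Cityscapes label definition in cityscapesScripts. The names `ID_TO_TRAIN_ID` and `to_train_ids` are hypothetical and are not part of this codebase. Use the official csCreateTrainIdLabelImgs.py to generate the actual label maps.

```python
# Illustrative sketch only: the official cityscapesScripts tool
# csCreateTrainIdLabelImgs.py should be used to produce the real trainId maps.
from typing import Dict, List

IGNORE = 255  # value for pixels whose class is not used for training/eval

# Cityscapes label id -> trainId for the 19 training classes
# (road, sidewalk, building, wall, fence, pole, traffic light, traffic sign,
#  vegetation, terrain, sky, person, rider, car, truck, bus, train,
#  motorcycle, bicycle). Every other id maps to IGNORE.
ID_TO_TRAIN_ID: Dict[int, int] = {
    7: 0, 8: 1, 11: 2, 12: 3, 13: 4, 17: 5, 19: 6, 20: 7, 21: 8, 22: 9,
    23: 10, 24: 11, 25: 12, 26: 13, 27: 14, 28: 15, 31: 16, 32: 17, 33: 18,
}


def to_train_ids(label_map: List[List[int]]) -> List[List[int]]:
    # Remap every pixel of a 2D id map; void or unused ids become IGNORE.
    return [[ID_TO_TRAIN_ID.get(i, IGNORE) for i in row] for row in label_map]


if __name__ == '__main__':
    demo = [[7, 26, 0], [24, 33, 9]]  # road, car, unlabeled / person, bicycle, parking
    print(to_train_ids(demo))  # [[0, 13, 255], [11, 18, 255]]
```

Train ids 0-10 are the 11 stuff classes. Train ids 11-18 (person to bicycle) are the 8 thing classes, which matches the split described above.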
\n\nThen the final folder is like this:\n\n```\n├── cityscapes\n│   ├── annotations\n│   │   ├── instancesonly_filtered_gtFine_train.json # coco instance annotation file(COCO format)\n│   │   ├── instancesonly_filtered_gtFine_val.json\n│   │   ├── cityscapes_panoptic_train.json  # panoptic json file \n│   │   ├── cityscapes_panoptic_val.json  \n│   ├── leftImg8bit\n│   ├── gtFine\n│   │   ├──cityscapes_panoptic_{train,val}/  # png annotations\n│   │   \n```\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2022 Xiangtai  Lee\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# Video K-Net: A Simple, Strong, and Unified Baseline for Video Segmentation (CVPR-2022, oral) \n## [Paper](https://arxiv.org/abs/2204.04656), [Slides](./slides/Video-KNet-cvpr-slides-10-25-version.pptx), [Poster](./slides/cvpr22_poster_lxt_zww_pjm.pdf), [Video](https://www.youtube.com/watch?v=LIEyp_czu20&t=3s)\n\n[Xiangtai Li](https://lxtgh.github.io/),\n[Wenwei Zhang](https://zhangwenwei.cn/),\n[Jiangmiao Pang](https://oceanpang.github.io/),\n[Kai Chen](https://chenkai.site/), \n[Guangliang Cheng](https://scholar.google.com/citations?user=FToOC-wAAAAJ),\n[Yunhai Tong](https://scholar.google.com/citations?user=T4gqdPkAAAAJ&hl=zh-CN),\n[Chen Change Loy](https://www.mmlab-ntu.com/person/ccloy/).\n\nWe introduce Video K-Net, a simple, strong, and unified framework for fully end-to-end dense video segmentation. \n\nThe method is built upon K-Net, which unifies image segmentation via a group of learnable kernels.\n\nThis project contains the training and testing code of Video K-Net for VPS (Video Panoptic Segmentation), \nVSS (Video Semantic Segmentation), and VIS (Video Instance Segmentation).\n\nTo the best of our knowledge, Video K-Net is the first open-source method that supports three different video segmentation tasks (VIS, VPS, VSS) for video scene understanding.\n\n## News! Video K-Net is acknowledged as a strong baseline for the CVPR-2023 workshop [\"The 2nd Pixel-level Video Understanding in the Wild\"](https://www.vspwdataset.com/Workshop%202023.html). \n## News! Video K-Net also supports the [VIP-Seg](https://github.com/VIPSeg-Dataset/VIPSeg-Dataset) dataset (CVPR-2022), where it also achieves a new state-of-the-art result.\n\n\n### Environment and Dataset Preparation \nOur codebase is based on MMDetection and MMSegmentation. 
Parts of the code are borrowed from MMTracking and UniTrack.\n\n- MIM >= 0.1.1\n- MMCV-full >= v1.3.8\n- MMDetection == v2.18.0\n- timm\n- scipy\n- panopticapi\n\nSee [DATASET.md](https://github.com/lxtGH/Video-K-Net/blob/main/DATASET.md) for dataset preparation.\n\nThe knet folder contains Video K-Net for VPS.\n\nThe knet_vis folder contains Video K-Net for VIS.\n\n\n\n### Pretrained CKPTs and Trained Models\n\nWe provide the pretrained models for VPS and VIS.\n\nBaidu Yun Link: [here](https://pan.baidu.com/s/12dIinkAF3o60fcAoggVhjQ)  Code: i034\n\nOne Drive Link: [here](https://1drv.ms/u/s!Ai4mxaXd6lVBgSCTUS0QWNim2zGx?e=uceSee)\n\nThe pretrained models are provided for training Video K-Net.\n\nThe trained models are also provided for testing and demos.\n\n\n\n### [VPS] KITTI-STEP\n\n1. First, pretrain K-Net on the Cityscapes-STEP dataset. As shown in the original STEP paper (Appendix) and in our own experiments, this step is important for segmentation performance.\nYou can also use our trained model for verification.\n\nCityscapes-STEP follows the STEP format: 17 stuff classes and 2 thing classes. \n\n```bash\n# train cityscapes step panoptic segmentation models\nsh ./tools/slurm_train.sh $PARTITION knet_step configs/det/knet_cityscapes_step/knet_s3_r50_fpn.py $WORK_DIR --no-validate\n```\n\n2. Then train Video K-Net on KITTI-STEP. 
We provide the Cityscapes-pretrained models for Video K-Net.\n\nFor slurm users:\n\n```bash\n# train Video K-Net on KITTI-step using R-50\nGPUS=8 sh ./tools/slurm_train.sh $PARTITION video_knet_step configs/det/video_knet_kitti_step/video_knet_s3_r50_rpn_1x_kitti_step_sigmoid_stride2_mask_embed_link_ffn_joint_train.py $WORK_DIR --no-validate --load-from /path_to_knet_step_city_r50\n```\n\n```bash\n# train Video K-Net on KITTI-step using Swin-base\nGPUS=16 GPUS_PER_NODE=8 sh ./tools/slurm_train.sh $PARTITION video_knet_step configs/det/video_knet_kitti_step/video_knet_s3_swinb_rpn_1x_kitti_step_sigmoid_stride2_mask_embed_link_ffn_joint_train.py $WORK_DIR --no-validate --load-from /path_to_knet_step_city_r50\n```\n\nOur models are trained on two V100 machines. \n\nFor a local machine:\n\n```bash\n# train Video K-Net on KITTI-step with 8 GPUs\nsh ./tools/dist_train.sh video_knet_step configs/det/video_knet_kitti_step/video_knet_s3_r50_rpn_1x_kitti_step_sigmoid_stride2_mask_embed_link_ffn_joint_train.py 8 $WORK_DIR --no-validate\n```\n\n\n3. Testing and demo.\n\nWe provide both VPQ and STQ metrics to evaluate VPS models. \n\n```bash\n# test locally \nsh ./tools/dist_step_test.sh configs/det/knet_cityscapes_step/knet_s3_r50_fpn.py $MODEL_DIR \n```\n\nWe also dump colored images for debugging.\n\n```bash\n# eval STEP STQ\npython tools/eval_dstq_step.py result_path gt_path\n```\n\n```bash\n# eval STEP VPQ\npython tools/eval_dvpq_step.py result_path gt_path\n```\n\n#### Toy Video K-Net \n\nAs shown in the paper, we also provide a toy Video K-Net in knet/video/knet_quansi_dense_embed_fc_toy_exp.py. \nIt uses the K-Net pretrained on image-level KITTI-STEP, without tracking.\n\n\n### [VIS] YouTube-VIS-2019\n\n1. First, download the pretrained image K-Net instance segmentation models. All the models are pretrained on COCO, which is\na common step. You can also pretrain them yourself. 
We also provide the config for pretraining.\n\nFor slurm users:\n\n```bash\n# train K-Net instance segmentation models on COCO using R-50\nGPUS=8 sh ./tools/slurm_train.sh $PARTITION knet_instance configs/det/coco/knet_s3_r50_fpn_ms-3x_coco.py $WORK_DIR \n```\n\n2. Then train Video K-Net in a clip-wise manner. \n\n```bash\n# train Video K-Net VIS models using R-50\nGPUS=8 sh ./tools/slurm_train.sh $PARTITION video_knet_vis configs/video_knet_vis/video_knet_vis/knet_track_r50_1x_youtubevis.py $WORK_DIR --load-from /path_to_knet_instance_coco\n```\n\n3. To evaluate Video K-Net on VIS, dump the prediction results and submit them to the CodaLab server. \n\n```bash\n# test Video K-Net VIS models using R-50\nGPUS=8 sh tools_vis/dist_test_whole_video.sh $PARTITION video_knet_vis configs/video_knet_vis/video_knet_vis/knet_track_r50_1x_youtubevis.py $WORK_DIR --format-only\n```\nThe result json is dumped into the root of this codebase. \n\n### [VPS] VIP-Seg\n\n1. First, download the pretrained image K-Net panoptic segmentation models. All the models are pretrained on COCO, which is\na common step following VIP-Seg. You can also pretrain them yourself; we also provide the config for pretraining.\n```bash\n# train K-Net on COCO Panoptic Segmentation\nGPUS=8 sh ./tools/slurm_train.sh $PARTITION knet_coco configs/det/coco/knet_s3_r50_fpn_ms-3x_coco-panoptic.py $WORK_DIR \n```\n\n2. Train Video K-Net on the VIP-Seg dataset. \n```bash\n# train Video K-Net on VIP-Seg\nGPUS=8 sh ./tools/slurm_train.sh $PARTITION video_knet_vis configs/det/video_knet_vipseg/video_knet_s3_r50_rpn_vipseg_mask_embed_link_ffn_joint_train.py $WORK_DIR --load-from /path/knet_coco_pretrained_r50\n```\n\n3. 
Test Video K-Net on the VIP-Seg val set.\n```bash\n# test locally on VIP-Seg\nsh ./tools/dist_step_test.sh configs/det/video_knet_vipseg/video_knet_s3_r50_rpn_vipseg_mask_embed_link_ffn_joint_train.py $MODEL_DIR \n```\n\nWe also dump colored images for debugging.\n\n```bash\n# eval VIP-Seg STQ\npython tools/eval_dstq_vipseg.py result_path gt_path\n```\n\n```bash\n# eval VIP-Seg VPQ\npython tools/eval_dvpq_vipseg.py result_path gt_path\n```\n\n\n## Visualization Results\n\n\n### Results on KITTI-STEP Dataset\n\n\n\n### Results on VIP-Seg Dataset\n\n\n\n### Results on YouTube-VIS Dataset\n\n\n\n### Short-term segmentation and tracking results on the Cityscapes-VPS dataset\n\nInput images (left), Video K-Net (middle), ground truth (right)\n![Alt Text](./figs/cityscapes_vps_video_1_20220318131729.gif)\n\n![Alt Text](./figs/cityscapes_vps_video_2_20220318132943.gif)\n\n### Long-term segmentation and tracking results on the STEP dataset\n\n![Alt Text](./figs/step_video_1_20220318133227.gif)\n\n![Alt Text](./figs/step_video_2_20220318133423.gif)\n\n\n## Related Project and Acknowledgement\n\nNeurIPS-2021, K-Net: Towards Unified Image Segmentation, our image baseline (https://github.com/ZwwWayne/K-Net)\n\nECCV-2022, PolyphonicFormer: A Unified Framework For Panoptic Segmentation + Depth Estimation (winner of ICCV-2021 BMTT workshop)\n(https://github.com/HarborYuan/PolyphonicFormer)\n\n## Citing Video K-Net :pray:\n\nIf you use our codebase in your research or for the CVPR-2023 pixel-level video workshop, please cite the following BibTeX entries.\n\n```bibtex\n@inproceedings{li2022videoknet,\n  title={Video k-net: A simple, strong, and unified baseline for video segmentation},\n  author={Li, Xiangtai and Zhang, Wenwei and Pang, Jiangmiao and Chen, Kai and Cheng, Guangliang and Tong, Yunhai and Loy, Chen Change},\n  
journal={NeurIPS},\n  year={2021}\n}\n```\n\n"
  },
  {
    "path": "configs/det/_base_/datasets/cityscapes_panoptic.py",
    "content": "# dataset settings\ndataset_type = 'CityscapesPanopticDataset'\ndata_root = 'data/cityscapes/'\n\nimg_norm_cfg = dict(\n    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)\n\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='LoadAnnotations', with_bbox=True, with_mask=True, with_seg=True),\n    dict(\n        type='Resize', img_scale=[(2048, 800), (2048, 1024)], multiscale_mode='range', keep_ratio=True),\n    dict(type='RandomFlip', flip_ratio=0.5),\n    dict(type='Normalize', **img_norm_cfg),\n    dict(type='Pad', size_divisor=32),\n    dict(type='DefaultFormatBundle'),\n    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks', 'gt_semantic_seg']),\n]\n\n\ntest_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='MultiScaleFlipAug',\n        img_scale=(2048, 1024),\n        flip=False,\n        transforms=[\n            dict(type='Resize', keep_ratio=True),\n            dict(type='RandomFlip'),\n            dict(type='Normalize', **img_norm_cfg),\n            dict(type='Pad', size_divisor=32),\n            dict(type='ImageToTensor', keys=['img']),\n            dict(type='Collect', keys=['img']),\n        ])\n]\n\ndata = dict(\n    samples_per_gpu=2,\n    workers_per_gpu=2,\n    train=dict(\n        type='RepeatDataset',\n        times=8,\n        dataset=dict(\n            type=dataset_type,\n            ann_file=dict(\n                ins_ann=data_root + 'annotations/instancesonly_filtered_gtFine_train.json',\n                panoptic_ann=data_root + 'annotations/cityscapes_panoptic_train.json'\n            ),\n            img_prefix=data_root + 'leftImg8bit/train/',\n            seg_prefix=data_root + 'gtFine/train',\n            pipeline=train_pipeline)),\n    val=dict(\n        type=dataset_type,\n        ann_file=dict(\n            ins_ann=data_root +'annotations/instancesonly_filtered_gtFine_val.json',\n            panoptic_ann=data_root + 
\"annotations/cityscapes_panoptic_val.json\"\n        ),\n        img_prefix=data_root + 'leftImg8bit/val/',\n        seg_prefix=data_root + 'gtFine/cityscapes_panoptic_val',\n        pipeline=test_pipeline),\n    test=dict(\n        type=dataset_type,\n        ann_file=dict(\n            ins_ann=data_root + 'annotations/instancesonly_filtered_gtFine_val.json',\n            panoptic_ann=data_root + \"annotations/cityscapes_panoptic_val.json\"\n        ),\n        img_prefix=data_root + 'leftImg8bit/val/',\n        seg_prefix=data_root + 'gtFine/cityscapes_panoptic_val',\n        pipeline=test_pipeline))\n\nevaluation = dict(metric=['panoptic'])\n"
  },
  {
    "path": "configs/det/_base_/datasets/cityscapes_step.py",
    "content": "dataset_type = 'CityscapesSTEP'\ndata_root = 'data/cityscapes'\n\nimg_norm_cfg = dict(\n    mean=[123.675, 116.28, 103.53],\n    std=[58.395, 57.12, 57.375],\n    to_rgb=True\n)\n\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='LoadAnnotationsInstanceMasks', cherry=[11, 13]),\n    dict(type='KNetInsAdapterCherryPick', stuff_nums=11, cherry=[11, 13]),\n    dict(type='Resize', img_scale=(1024, 2048), ratio_range=[0.5, 2.0], keep_ratio=True),\n    dict(type='RandomFlip', flip_ratio=0.5),\n    dict(type='RandomCrop', crop_size=(1024, 2048)),\n    dict(type='Normalize', **img_norm_cfg),\n    dict(type='PadFutureMMDet', size_divisor=32, pad_val=dict(img=0, masks=0, seg=255)),\n    dict(type='DefaultFormatBundle'),\n    dict(type='Collect', keys=['img', 'gt_masks', 'gt_labels', 'gt_semantic_seg'],\n         meta_keys=('ori_shape', 'img_shape', 'pad_shape', 'scale_factor', 'flip',\n                    'flip_direction', 'img_norm_cfg')\n         ),\n]\n\ntest_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='MultiScaleFlipAug',\n        scale_factor=[1.0],\n        flip=False,\n        transforms=[\n            dict(type='RandomFlip'),\n            dict(type='Normalize', **img_norm_cfg),\n            dict(type='Pad', size_divisor=32),\n            dict(type='ImageToTensor', keys=['img']),\n            dict(type='Collect',\n                 keys=['img'],\n                 meta_keys=[\n                     'ori_shape', 'img_shape', 'pad_shape', 'scale_factor', 'flip',\n                     'flip_direction', 'img_norm_cfg'\n                 ]),\n        ])\n]\n\ndata = dict(\n    samples_per_gpu=4,\n    workers_per_gpu=2,\n    train=dict(\n        type='RepeatDataset',\n        times=8,\n        dataset=dict(\n            type=dataset_type,\n            data_root=data_root,\n            split='train',\n            test_mode=False,\n            pipeline=train_pipeline\n        )),\n    val=dict(\n        
type=dataset_type,\n        data_root=data_root,\n        split='val',\n        test_mode=True,\n        pipeline=test_pipeline\n    ),\n    test=dict(\n        type=dataset_type,\n        data_root=data_root,\n        split='val',\n        test_mode=True,\n        pipeline=test_pipeline\n    )\n)\n\nevaluation = dict()\n"
  },
  {
    "path": "configs/det/_base_/datasets/cityscapes_vps_clips.py",
    "content": "dataset_type = 'CityscapesVPSDataset'\ndata_root = 'data/cityscapes_vps/'\ndataset_type_test = \"CityscapesPanopticDataset\"\n\nimg_norm_cfg = dict(\n    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)\n\ntrain_pipeline = [\n    dict(type='LoadMultiImagesFromFile'),\n    dict(type='SeqLoadAnnotations', with_bbox=True, with_mask=True, with_seg=True),\n    dict(type='SeqResize', img_scale=[(512, 1024), (2048, 4096)], multiscale_mode='range', keep_ratio=True),\n    dict(type='SeqRandomFlip',  share_params=True, flip_ratio=0.5),\n    dict(type='SeqRandomCrop',  crop_size=(1024, 1024), share_params=True),\n    dict(type='SeqNormalize', **img_norm_cfg),\n    dict(type='SeqPad', size_divisor=32),\n    dict(\n        type='VideoCollect',\n        keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks', 'gt_semantic_seg', \"gt_instance_ids\"]),\n    dict(type='ConcatVideoReferences'),\n    dict(type='SeqDefaultFormatBundle', ref_prefix='ref'),\n]\n\n\ntest_pipeline = [\n    dict(type='LoadRefImageFromFile'),\n\n    dict(\n        type='MultiScaleFlipAug',\n        img_scale=[(2048, 1024)],\n        flip=False,\n        transforms=[\n            dict(type='Resize', keep_ratio=True),\n            dict(type='RandomFlip'),\n            dict(type='Normalize', **img_norm_cfg),\n            dict(type='Pad', size_divisor=32),\n            dict(type='ImageToTensor', keys=['img', 'ref_img']),\n            dict(type='Collect', keys=['img', 'ref_img']),\n        ])\n]\n\ndata = dict(\n    samples_per_gpu=2,\n    workers_per_gpu=2,\n    train=dict(\n        type='RepeatDataset',\n        times=8,\n        dataset=dict(\n            type=dataset_type,\n            ann_file=dict(ins_ann=data_root +'instances_train_city_vps_rle.json',\n                          panoptic_ann=data_root + 'panoptic_im_train_city_vps.json'\n                          ),\n            img_prefix=data_root + 'train/img/',\n            seg_prefix=data_root + 'train/labelmap/',\n  
          pipeline=train_pipeline,\n            offsets=[-1,+1])),\n    val=dict(\n        type=dataset_type_test,\n        ann_file=dict(ins_ann=data_root + 'instances_val_city_vps_rle.json',\n                      panoptic_ann=data_root + 'panoptic_gt_val_city_vps.json',\n                      vps=True\n                      ),\n        img_prefix=data_root + 'val/img/',\n        seg_prefix=data_root + 'val/panoptic_video/',\n        pipeline=test_pipeline),\n    test=dict(\n        type=dataset_type_test,\n        ann_file=dict(ins_ann=data_root + 'instances_val_city_vps_rle.json',\n                      panoptic_ann=data_root + 'panoptic_gt_val_city_vps.json',\n                      vps=True\n                      ),\n        img_prefix=data_root + 'val/img_all/',     # img for validation\n        ref_prefix=data_root + 'val/img_all/',  # ref_images\n        nframes_span_test=30,\n        pipeline=test_pipeline))\n\nevaluation = dict(metric=['panoptic'])"
  },
  {
    "path": "configs/det/_base_/datasets/cityscapes_vps_clips_trainval.py",
    "content": "dataset_type = 'CityscapesVPSDataset'\ndata_root = 'data/cityscapes_vps/'\ndataset_type_test = \"CityscapesPanopticDataset\"\n\nimg_norm_cfg = dict(\n    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)\n\ntrain_pipeline = [\n    dict(type='LoadMultiImagesFromFile'),\n    dict(type='SeqLoadAnnotations', with_bbox=True, with_mask=True, with_seg=True),\n    dict(type='SeqResize', img_scale=[(512, 1024), (2048, 4096)], multiscale_mode='range', keep_ratio=True),\n    dict(type='SeqRandomFlip',  share_params=True, flip_ratio=0.5),\n    dict(type='SeqRandomCrop',  crop_size=(1024, 2048), share_params=True),\n    dict(type='SeqNormalize', **img_norm_cfg),\n    dict(type='SeqPad', size_divisor=32),\n    dict(\n        type='VideoCollect',\n        keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks', 'gt_semantic_seg', \"gt_instance_ids\"]),\n    dict(type='ConcatVideoReferences'),\n    dict(type='SeqDefaultFormatBundle', ref_prefix='ref'),\n]\n\n\ntest_pipeline = [\n    dict(type='LoadRefImageFromFile'),\n\n    dict(\n        type='MultiScaleFlipAug',\n        img_scale=[(2048, 1024)],\n        flip=False,\n        transforms=[\n            dict(type='Resize', keep_ratio=True),\n            dict(type='RandomFlip'),\n            dict(type='Normalize', **img_norm_cfg),\n            dict(type='Pad', size_divisor=32),\n            dict(type='ImageToTensor', keys=['img', 'ref_img']),\n            dict(type='Collect', keys=['img', 'ref_img']),\n        ])\n]\n\ndata = dict(\n    samples_per_gpu=1,\n    workers_per_gpu=2,\n    train=dict(\n        type='RepeatDataset',\n        times=8,\n        dataset=dict(\n            type='ConcatDataset',\n            separate_eval=False,\n            datasets=[\n                dict(\n                    type=dataset_type,\n                    ann_file=dict(ins_ann=data_root +'instances_train_city_vps_rle.json',\n                                  panoptic_ann=data_root + 
'panoptic_im_train_city_vps.json'\n                                  ),\n                    img_prefix=data_root + 'train/img/',\n                    seg_prefix=data_root + 'train/labelmap/',\n                    pipeline=train_pipeline,\n                    offsets=[-1,+1]\n                ),\n            dict(\n                type=dataset_type,\n                ann_file=dict(ins_ann=data_root +'instances_val_city_vps_rle.json',\n                              panoptic_ann=data_root + 'panoptic_gt_val_city_vps.json'\n                              ),\n                img_prefix=data_root + 'val/img/',\n                seg_prefix=data_root + 'val/labelmap/',\n                pipeline=train_pipeline,\n                offsets=[-1,+1]),\n            ],\n        )\n    ),\n    val=dict(\n        type=dataset_type,\n        ann_file=dict(ins_ann=data_root + 'instances_val_city_vps_rle.json',\n                      panoptic_ann=data_root + 'panoptic_gt_val_city_vps.json',\n                      vps=True\n                      ),\n        img_prefix=data_root + 'val/img_all/',     # img for validation\n        ref_prefix=data_root + 'val/img_all/',  # ref_images\n        nframes_span_test=30,\n        pipeline=test_pipeline)\n\n)\n\nevaluation = dict(metric=['panoptic'])"
  },
  {
    "path": "configs/det/_base_/datasets/coco_instance.py",
    "content": "dataset_type = 'CocoDataset'\ndata_root = 'data/coco/'\nimg_norm_cfg = dict(\n    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),\n    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),\n    dict(type='RandomFlip', flip_ratio=0.5),\n    dict(type='Normalize', **img_norm_cfg),\n    dict(type='Pad', size_divisor=32),\n    dict(type='DefaultFormatBundle'),\n    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),\n]\ntest_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='MultiScaleFlipAug',\n        img_scale=(1333, 800),\n        flip=False,\n        transforms=[\n            dict(type='Resize', keep_ratio=True),\n            dict(type='RandomFlip'),\n            dict(type='Normalize', **img_norm_cfg),\n            dict(type='Pad', size_divisor=32),\n            dict(type='ImageToTensor', keys=['img']),\n            dict(type='Collect', keys=['img']),\n        ])\n]\ndata = dict(\n    samples_per_gpu=2,\n    workers_per_gpu=2,\n    train=dict(\n        type=dataset_type,\n        ann_file=data_root + 'annotations/instances_train2017.json',\n        img_prefix=data_root + 'train2017/',\n        pipeline=train_pipeline),\n    val=dict(\n        type=dataset_type,\n        ann_file=data_root + 'annotations/instances_val2017.json',\n        img_prefix=data_root + 'val2017/',\n        pipeline=test_pipeline),\n    test=dict(\n        type=dataset_type,\n        ann_file=data_root + 'annotations/instances_val2017.json',\n        img_prefix=data_root + 'val2017/',\n        pipeline=test_pipeline))\n# we do not evaluate bbox because K-Net does not predict bounding boxes\nevaluation = dict(metric=['segm'])\n"
  },
  {
    "path": "configs/det/_base_/datasets/coco_panoptic.py",
    "content": "dataset_type = 'CocoPanopticDatasetCustom'\ndata_root = 'data/coco/'\nimg_norm_cfg = dict(\n    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='LoadAnnotations', with_bbox=True, with_mask=True, with_seg=True),\n    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),\n    dict(type='RandomFlip', flip_ratio=0.5),\n    dict(type='Normalize', **img_norm_cfg),\n    dict(type='Pad', size_divisor=32),\n    dict(type='DefaultFormatBundle'),\n    dict(\n        type='Collect',\n        keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks', 'gt_semantic_seg']),\n]\ntest_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='MultiScaleFlipAug',\n        img_scale=(1333, 800),\n        flip=False,\n        transforms=[\n            dict(type='Resize', keep_ratio=True),\n            dict(type='RandomFlip', flip_ratio=0.5),\n            dict(type='Normalize', **img_norm_cfg),\n            dict(type='Pad', size_divisor=32),\n            dict(type='ImageToTensor', keys=['img']),\n            dict(type='Collect', keys=['img']),\n        ])\n]\ndata = dict(\n    samples_per_gpu=2,\n    workers_per_gpu=2,\n    train=dict(\n        type=dataset_type,\n        ann_file=dict(\n            ins_ann=data_root + 'annotations/instances_train2017.json',\n            panoptic_ann=data_root + 'annotations/panoptic_train2017.json'),\n        img_prefix=data_root + 'train2017/',\n        seg_prefix=data_root + 'panoptic_stuff_train2017/',\n        pipeline=train_pipeline),\n    val=dict(\n        type=dataset_type,\n        ann_file=dict(\n            ins_ann=data_root + 'annotations/instances_val2017.json',\n            panoptic_ann=data_root + 'annotations/panoptic_val2017.json'),\n        seg_prefix=data_root + 'panoptic_val2017/',\n        img_prefix=data_root + 'val2017/',\n        pipeline=test_pipeline),\n    test=dict(\n        
type=dataset_type,\n        ann_file=dict(\n            ins_ann=data_root + 'annotations/instances_val2017.json',\n            panoptic_ann=data_root + 'annotations/panoptic_val2017.json'),\n        seg_prefix=data_root + 'panoptic_val2017/',\n        img_prefix=data_root + 'val2017/',\n        pipeline=test_pipeline))\n\nevaluation = dict(metric=['segm', 'panoptic'])\n"
  },
  {
    "path": "configs/det/_base_/datasets/coco_panoptic_instance_annotations.py",
    "content": "dataset_type = 'CocoPanopticDatasetCustom'\ndata_root = 'data/coco/'\nimg_norm_cfg = dict(\n    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='LoadAnnotations', with_bbox=True, with_mask=True, with_seg=True),\n    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),\n    dict(type='RandomFlip', flip_ratio=0.5),\n    dict(type='Normalize', **img_norm_cfg),\n    dict(type='Pad', size_divisor=32),\n    dict(type='DefaultFormatBundle'),\n    dict(\n        type='Collect',\n        keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks', 'gt_semantic_seg']),\n]\ntest_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='MultiScaleFlipAug',\n        img_scale=(1333, 800),\n        flip=False,\n        transforms=[\n            dict(type='Resize', keep_ratio=True),\n            dict(type='RandomFlip', flip_ratio=0.5),\n            dict(type='Normalize', **img_norm_cfg),\n            dict(type='Pad', size_divisor=32),\n            dict(type='ImageToTensor', keys=['img']),\n            dict(type='Collect', keys=['img']),\n        ])\n]\ndata = dict(\n    samples_per_gpu=2,\n    workers_per_gpu=2,\n    train=dict(\n        type=dataset_type,\n        ann_file=dict(\n            ins_ann=data_root + 'annotations/instances_train2017.json',\n            panoptic_ann=data_root + 'annotations/panoptic_train2017.json'),\n        img_prefix=data_root + 'train2017/',\n        seg_prefix=data_root + 'panoptic_stuff_train2017/',\n        pipeline=train_pipeline),\n    val=dict(\n        type=dataset_type,\n        ann_file=dict(\n            ins_ann=data_root + 'annotations/instances_val2017.json',\n            panoptic_ann=data_root + 'annotations/panoptic_val2017.json'),\n        seg_prefix=data_root + 'panoptic_val2017/',\n        img_prefix=data_root + 'val2017/',\n        pipeline=test_pipeline),\n    test=dict(\n        
type=dataset_type,\n        ann_file=dict(\n            ins_ann=data_root + 'annotations/instances_val2017.json',\n            panoptic_ann=data_root + 'annotations/panoptic_val2017.json'),\n        seg_prefix=data_root + 'panoptic_val2017/',\n        img_prefix=data_root + 'val2017/',\n        pipeline=test_pipeline))\n\nevaluation = dict(metric=['segm', 'panoptic'])\n"
  },
  {
    "path": "configs/det/_base_/datasets/kitti_step_dvps.py",
    "content": "dataset_type = 'KITTISTEPDVPSDataset'\ndata_root = 'data/kitti-step'\n\nimg_norm_cfg = dict(\n    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=False)\n\n# The kitti dataset contains 1226 x 370 and 1241 x 376\ntrain_pipeline = [\n    dict(type='LoadMultiImagesDirect'),\n    dict(type='LoadMultiAnnotationsDirect', with_depth=True, divisor=-1, cherry_pick=True, cherry=[11, 13]),\n    # dict(type='SeqResizeWithDepth', img_scale=(370, 1226), ratio_range=[1.0, 2.0], keep_ratio=True),\n    dict(type='SeqFlipWithDepth', flip_ratio=0.5),\n    # dict(type='SeqRandomCropWithDepth', crop_size=(352, 1024), share_params=True),\n    dict(type='SeqNormalizeWithDepth', **img_norm_cfg),\n    dict(type='SeqPadWithDepth', size_divisor=32),\n    dict(\n        type='VideoCollect',\n        keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks', 'gt_semantic_seg', 'gt_depth', 'gt_instance_ids', ]),\n    dict(type='ConcatVideoReferences'),\n    dict(type='SeqDefaultFormatBundle', ref_prefix='ref'),\n]\n\ntest_pipeline = [\n    dict(type='LoadImgDirect'),\n    dict(\n        type='MultiScaleFlipAug',\n        scale_factor=[1.0],\n        flip=False,\n        transforms=[\n            dict(type='RandomFlip'),\n            dict(type='Normalize', **img_norm_cfg),\n            dict(type='Pad', size_divisor=32),\n            dict(type='ImageToTensor', keys=['img']),\n            dict(type='Collect',\n                 keys=['img'],\n                 meta_keys=[\n                     'ori_shape', 'img_shape', 'pad_shape', 'scale_factor', 'flip',\n                     'flip_direction', 'img_norm_cfg', 'ori_filename'\n                 ]),\n        ])\n]\n\ndata = dict(\n    samples_per_gpu=1,\n    workers_per_gpu=2,\n    train=dict(\n        type='RepeatDataset',\n        times=4,\n        dataset=dict(\n            type=dataset_type,\n            data_root=data_root,\n            split='train',\n            ref_seq_index=None,\n            test_mode=False,\n      
      pipeline=train_pipeline,\n            with_depth=True,\n        )),\n    val=dict(\n        type=dataset_type,\n        data_root=data_root,\n        split='val',\n        ref_seq_index=None,\n        test_mode=True,\n        pipeline=test_pipeline,\n        with_depth=True,\n    ),\n    test=dict(\n        type=dataset_type,\n        data_root=data_root,\n        split='val',\n        ref_seq_index=None,\n        test_mode=True,\n        pipeline=test_pipeline,\n        with_depth=True,\n    )\n)\n\nevaluation = dict()\n"
  },
  {
    "path": "configs/det/_base_/datasets/kitti_step_vps.py",
    "content": "dataset_type = 'KITTISTEPDVPSDataset'\ndata_root = 'data/kitti-step'\n\nimg_norm_cfg = dict(\n    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=False)\n\n# The kitti dataset contains 1226 x 370 and 1241 x 376\n# 384 x 1248 is the minimum size that is 32-divisible\ntrain_pipeline = [\n    dict(type='LoadMultiImagesDirect'),\n    dict(type='LoadMultiAnnotationsDirect', with_depth=False, divisor=-1, cherry_pick=True, cherry=[11, 13]),\n    dict(type='SeqResizeWithDepth', img_scale=(384, 1248), ratio_range=[0.5, 2.0], keep_ratio=True),\n    dict(type='SeqFlipWithDepth', flip_ratio=0.5),\n    dict(type='SeqRandomCropWithDepth', crop_size=(384, 1248), share_params=True),\n    dict(type='SeqNormalizeWithDepth', **img_norm_cfg),\n    dict(type='SeqPadWithDepth', size_divisor=32),\n    dict(\n        type='VideoCollect',\n        keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks', 'gt_semantic_seg', 'gt_instance_ids']),\n    dict(type='ConcatVideoReferences'),\n    dict(type='SeqDefaultFormatBundle', ref_prefix='ref'),\n]\n\ntest_pipeline = [\n    dict(type='LoadImgDirect'),\n    dict(\n        type='MultiScaleFlipAug',\n        scale_factor=[1.0],\n        flip=False,\n        transforms=[\n            dict(type='RandomFlip'),\n            dict(type='Normalize', **img_norm_cfg),\n            dict(type='Pad', size_divisor=32),\n            dict(type='ImageToTensor', keys=['img']),\n            dict(type='Collect',\n                 keys=['img'],\n                 meta_keys=[\n                     'ori_shape', 'img_shape', 'pad_shape', 'scale_factor', 'flip',\n                     'flip_direction', 'img_norm_cfg', 'ori_filename', \"filename\"\n                 ]),\n        ])\n]\n\ndata = dict(\n    samples_per_gpu=1,\n    workers_per_gpu=2,\n    train=dict(\n        type='RepeatDataset',\n        times=4,\n        dataset=dict(\n            type=dataset_type,\n            data_root=data_root,\n            split='train',\n            
ref_seq_index=None,\n            test_mode=False,\n            pipeline=train_pipeline,\n            with_depth=False,\n        )),\n    val=dict(\n        type=dataset_type,\n        data_root=data_root,\n        split='val',\n        ref_seq_index=None,\n        test_mode=True,\n        pipeline=test_pipeline,\n        with_depth=False,\n    ),\n    test=dict(\n        type=dataset_type,\n        data_root=data_root,\n        split='val',\n        ref_seq_index=None,\n        test_mode=True,\n        pipeline=test_pipeline,\n        with_depth=False,\n    )\n)\n\nevaluation = dict()\n"
  },
  {
    "path": "configs/det/_base_/datasets/kitti_step_vps_trainval.py",
    "content": "dataset_type = 'KITTISTEPDVPSDataset'\ndata_root = 'data/kitti-step'\n\nimg_norm_cfg = dict(\n    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=False)\n\n# The kitti dataset contains 1226 x 370 and 1241 x 376\n# 384 x 1248 is the minimum size that is 32-divisible\ntrain_pipeline = [\n    dict(type='LoadMultiImagesDirect'),\n    dict(type='LoadMultiAnnotationsDirect', with_depth=False, divisor=-1, cherry_pick=True, cherry=[11, 13]),\n    dict(type='SeqResizeWithDepth', img_scale=(384, 1248), ratio_range=[0.5, 2.0], keep_ratio=True),\n    dict(type='SeqFlipWithDepth', flip_ratio=0.5),\n    dict(type='SeqRandomCropWithDepth', crop_size=(384, 1248), share_params=True),\n    dict(type='SeqNormalizeWithDepth', **img_norm_cfg),\n    dict(type='SeqPadWithDepth', size_divisor=32),\n    dict(\n        type='VideoCollect',\n        keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks', 'gt_semantic_seg', 'gt_instance_ids']),\n    dict(type='ConcatVideoReferences'),\n    dict(type='SeqDefaultFormatBundle', ref_prefix='ref'),\n]\n\ntest_pipeline = [\n    dict(type='LoadImgDirect'),\n    dict(\n        type='MultiScaleFlipAug',\n        scale_factor=[1.0],\n        flip=False,\n        transforms=[\n            dict(type='RandomFlip'),\n            dict(type='Normalize', **img_norm_cfg),\n            dict(type='Pad', size_divisor=32),\n            dict(type='ImageToTensor', keys=['img']),\n            dict(type='Collect',\n                 keys=['img'],\n                 meta_keys=[\n                     'ori_shape', 'img_shape', 'pad_shape', 'scale_factor', 'flip',\n                     'flip_direction', 'img_norm_cfg', 'ori_filename', \"filename\"\n                 ]),\n        ])\n]\n\ndata = dict(\n    samples_per_gpu=1,\n    workers_per_gpu=2,\n    train=dict(\n        type='RepeatDataset',\n        times=4,\n        dataset=dict(\n            type='ConcatDataset',\n            separate_eval=False,\n            datasets=[\n                
dict(\n                    type=dataset_type,\n                    data_root=data_root,\n                    split='train',\n                    ref_seq_index=None,\n                    test_mode=False,\n                    pipeline=train_pipeline,\n                    with_depth=False,\n                ),\n                dict(\n                    type=dataset_type,\n                    data_root=data_root,\n                    split='val',\n                    ref_seq_index=None,\n                    test_mode=False,\n                    pipeline=train_pipeline,\n                    with_depth=False,\n                )\n            ]\n        ),\n    ),\n    val=dict(\n        type=dataset_type,\n        data_root=data_root,\n        split='val',\n        ref_seq_index=None,\n        test_mode=True,\n        pipeline=test_pipeline,\n        with_depth=False,\n    ),\n    test=dict(\n        type=dataset_type,\n        data_root=data_root,\n        split='val',\n        ref_seq_index=None,\n        test_mode=True,\n        pipeline=test_pipeline,\n        with_depth=False,\n    )\n)\n\nevaluation = dict()\n"
  },
  {
    "path": "configs/det/_base_/datasets/mapillary_panoptic.py",
    "content": "dataset_type = 'MapillaryPanopticDataset'\ndata_root = 'data/mapillary/'\nimg_norm_cfg = dict(\n    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)\n\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='LoadAnnotations', with_bbox=True, with_mask=True, with_seg=True),\n    dict(type='Resize', img_scale=[(1024, 4096), (2048, 4096)], multiscale_mode='range', keep_ratio=True),\n    dict(type='RandomCrop', crop_size=(1024, 1024)),\n    dict(type='RandomFlip', flip_ratio=0.5),\n    dict(type='Normalize', **img_norm_cfg),\n    dict(type='Pad', size_divisor=32),\n    dict(type='DefaultFormatBundle'),\n    dict(\n        type='Collect',\n        keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks', 'gt_semantic_seg']),\n]\n\ntest_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='MultiScaleFlipAug',\n        img_scale=(2048, 4096),\n        flip=False,\n        transforms=[\n            dict(type='Resize', keep_ratio=True),\n            dict(type='RandomFlip', flip_ratio=0.5),\n            dict(type='Normalize', **img_norm_cfg),\n            dict(type='Pad', size_divisor=32),\n            dict(type='ImageToTensor', keys=['img']),\n            dict(type='Collect', keys=['img']),\n        ])\n]\n\ndata = dict(\n    samples_per_gpu=2,\n    workers_per_gpu=2,\n    train=dict(\n        type=dataset_type,\n        ann_file=dict(\n            ins_ann=data_root + 'annotations/coco/training.json',\n            panoptic_ann=data_root + 'annotations/panoptic_train.json'\n        ),\n        img_prefix=data_root + 'training/images',\n        seg_prefix=data_root + 'training/panoptic_stuff_train',\n        pipeline=train_pipeline),\n    val=dict(\n        type=dataset_type,\n        ann_file=dict(\n            ins_ann=data_root + 'annotations/coco/validation.json',\n            panoptic_ann=data_root + 'annotations/panoptic_val.json'),\n        seg_prefix=data_root + 'validation/panoptic',\n   
     img_prefix=data_root + 'validation/images',\n        pipeline=test_pipeline),\n    test=dict(\n        type=dataset_type,\n        ann_file=dict(\n            ins_ann=data_root + 'annotations/coco/validation.json',\n            panoptic_ann=data_root + 'annotations/panoptic_val.json'),\n        seg_prefix=data_root + 'validation/panoptic',\n        img_prefix=data_root + 'validation/images',\n        pipeline=test_pipeline))\n\nevaluation = dict(metric=['segm', 'panoptic'])\n"
  },
  {
    "path": "configs/det/_base_/datasets/vipseg_dvps.py",
    "content": "dataset_type = 'VIPSegDVPSDataset'\ndata_root = 'data/VIPSeg'\n\nimg_norm_cfg = dict(\n    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=False)\n\ncrop_size = (736, 736)\n\ntrain_pipeline = [\n    dict(type='LoadMultiImagesDirect'),\n    dict(type='LoadMultiAnnotationsDirect', with_depth=False, vipseg=True),\n    dict(type='SeqResizeWithDepth', img_scale=(720, 100000), ratio_range=[1., 2.], keep_ratio=True),\n    dict(type='SeqFlipWithDepth', flip_ratio=0.5),\n    dict(type='SeqRandomCropWithDepth', crop_size=crop_size, share_params=True),\n    dict(type='SeqPadWithDepth', size_divisor=32),\n    dict(type='SeqNormalize', **img_norm_cfg),\n    dict(\n        type='VideoCollect',\n        keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks', 'gt_semantic_seg', 'gt_instance_ids']),\n    dict(type='ConcatVideoReferences'),\n    dict(type='SeqDefaultFormatBundle', ref_prefix='ref'),\n]\n\n\ntest_pipeline = [\n    dict(type='LoadImgDirect'),\n    dict(\n        type='MultiScaleFlipAug',\n        scale_factor=[1.0],\n        flip=False,\n        transforms=[\n            dict(type='RandomFlip'),\n            dict(type='Normalize', **img_norm_cfg),\n            dict(type='Pad', size_divisor=32),\n            dict(type='ImageToTensor', keys=['img']),\n            dict(type='Collect',\n                 keys=['img', 'img_id', 'seq_id'],\n                 meta_keys=[\n                     'ori_shape', 'img_shape', 'pad_shape', 'scale_factor', 'flip',\n                     'flip_direction', 'img_norm_cfg', 'ori_filename', \"filename\"\n                 ]),\n        ])\n]\n\ndata = dict(\n    samples_per_gpu=1,\n    workers_per_gpu=2,\n    train=dict(\n        type='RepeatDataset',\n        times=1,\n        dataset=dict(\n            type=dataset_type,\n            data_root=data_root,\n            test_mode=False,\n            split='train',\n            ref_seq_index=[-2, -1, 1, 2],\n            is_instance_only=True,\n            
pipeline=train_pipeline,\n        )),\n    val=dict(\n        type=dataset_type,\n        data_root=data_root,\n        split='val',\n        ref_seq_index=None,\n        test_mode=True,\n        pipeline=test_pipeline,\n    ),\n    test=dict(\n        type=dataset_type,\n        data_root=data_root,\n        split='val',\n        ref_seq_index=None,\n        test_mode=True,\n        pipeline=test_pipeline,\n    )\n)\n\nevaluation = dict()\n"
  },
  {
    "path": "configs/det/_base_/default_runtime.py",
    "content": "checkpoint_config = dict(interval=1)\n# yapf:disable\nlog_config = dict(\n    interval=50,\n    hooks=[\n        dict(type='TextLoggerHook'),\n    ])\n# yapf:enable\n\ndist_params = dict(backend='nccl')\nlog_level = 'INFO'\nload_from = None\nresume_from = None\nworkflow = [('train', 1)]\n"
  },
  {
    "path": "configs/det/_base_/models/knet_citystep_s3_r50_fpn.py",
    "content": "num_stages = 3\nnum_proposals = 100\nconv_kernel_size = 1\n\nmodel = dict(\n    type='KNet',\n    cityscapes=False,\n    kitti_step=True,\n    num_thing_classes=2,\n    num_stuff_classes=17,\n    backbone=dict(\n        type='ResNet',\n        depth=50,\n        num_stages=4,\n        out_indices=(0, 1, 2, 3),\n        frozen_stages=1,\n        norm_cfg=dict(type='BN', requires_grad=True),\n        norm_eval=True,\n        style='pytorch',\n        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),\n    neck=dict(\n        type='FPN',\n        in_channels=[256, 512, 1024, 2048],\n        out_channels=256,\n        start_level=0,\n        add_extra_convs='on_input',\n        num_outs=4\n    ),\n    rpn_head=dict(\n        type='ConvKernelHead',\n        num_classes=19,\n        num_thing_classes=2,\n        num_stuff_classes=17,\n        cat_stuff_mask=True,\n        conv_kernel_size=conv_kernel_size,\n        feat_downsample_stride=2,\n        feat_refine_stride=1,\n        feat_refine=False,\n        use_binary=True,\n        num_loc_convs=1,\n        num_seg_convs=1,\n        conv_normal_init=True,\n        localization_fpn=dict(\n            type='SemanticFPNWrapper',\n            in_channels=256,\n            feat_channels=256,\n            out_channels=256,\n            start_level=0,\n            end_level=3,\n            upsample_times=2,\n            positional_encoding=dict(\n                type='SinePositionalEncoding', num_feats=128, normalize=True),\n            cat_coors=False,\n            cat_coors_level=3,\n            fuse_by_cat=False,\n            return_list=False,\n            num_aux_convs=1,\n            norm_cfg=dict(type='GN', num_groups=32, requires_grad=True)),\n        num_proposals=num_proposals,\n        proposal_feats_with_obj=True,\n        xavier_init_kernel=False,\n        kernel_init_std=1,\n        num_cls_fcs=1,\n        in_channels=256,\n        feat_transform_cfg=None,\n        
loss_rank=dict(\n            type='CrossEntropyLoss',\n            use_sigmoid=False,\n            loss_weight=0.1),\n        loss_seg=dict(\n            type='FocalLoss',\n            use_sigmoid=True,\n            gamma=2.0,\n            alpha=0.25,\n            loss_weight=1.0),\n        loss_mask=dict(\n            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),\n        loss_dice=dict(type='DiceLoss', loss_weight=4.0)),\n    roi_head=dict(\n        type='KernelIterHead',\n        num_thing_classes=2,\n        num_stuff_classes=17,\n        do_panoptic=True,\n        num_stages=num_stages,\n        stage_loss_weights=[1] * num_stages,\n        proposal_feature_channel=256,\n        mask_head=[\n            dict(\n                type='KernelUpdateHead',\n                num_classes=19,\n                num_thing_classes=2,\n                num_stuff_classes=17,\n                num_ffn_fcs=2,\n                num_heads=8,\n                num_cls_fcs=1,\n                num_mask_fcs=1,\n                feedforward_channels=2048,\n                in_channels=256,\n                out_channels=256,\n                dropout=0.0,\n                mask_thr=0.5,\n                conv_kernel_size=conv_kernel_size,\n                mask_upsample_stride=2,\n                ffn_act_cfg=dict(type='ReLU', inplace=True),\n                with_ffn=True,\n                feat_transform_cfg=dict(\n                    conv_cfg=dict(type='Conv2d'),\n                    act_cfg=None\n                ),\n                kernel_updator_cfg=dict(\n                    type='KernelUpdator',\n                    in_channels=256,\n                    feat_channels=256,\n                    out_channels=256,\n                    input_feat_shape=3,\n                    act_cfg=dict(type='ReLU', inplace=True),\n                    norm_cfg=dict(type='LN')),\n                loss_rank=dict(\n                    type='CrossEntropyLoss',\n                    use_sigmoid=False,\n 
                   loss_weight=0.1\n                ),\n                loss_mask=dict(\n                    type='CrossEntropyLoss',\n                    use_sigmoid=True,\n                    loss_weight=1.0\n                ),\n                loss_dice=dict(\n                    type='DiceLoss', loss_weight=4.0\n                ),\n                loss_cls=dict(\n                    type='FocalLoss',\n                    use_sigmoid=True,\n                    gamma=2.0,\n                    alpha=0.25,\n                    loss_weight=2.0))\n            for _ in range(num_stages)\n        ]\n    ),\n    # training and testing settings\n    train_cfg=dict(\n        rpn=dict(\n            assigner=dict(\n                type='MaskHungarianAssigner',\n                cls_cost=dict(type='FocalLossCost', weight=2.0),\n                dice_cost=dict(type='DiceCost', weight=4.0, pred_act=True),\n                mask_cost=dict(type='MaskCost', weight=1.0, pred_act=True)),\n            sampler=dict(type='MaskPseudoSampler'),\n            pos_weight=1),\n        rcnn=[\n            dict(\n                assigner=dict(\n                    type='MaskHungarianAssigner',\n                    cls_cost=dict(type='FocalLossCost', weight=2.0),\n                    dice_cost=dict(type='DiceCost', weight=4.0, pred_act=True),\n                    mask_cost=dict(type='MaskCost', weight=1.0,\n                                   pred_act=True)),\n                sampler=dict(type='MaskPseudoSampler'),\n                pos_weight=1)\n\n            for _ in range(num_stages)\n        ]),\n    test_cfg=dict(\n        rpn=None,\n        rcnn=dict(\n            max_per_img=num_proposals,\n            mask_thr=0.5,\n            stuff_score_thr=0.05,\n            merge_stuff_thing=dict(\n                iou_thr=0.5,\n                stuff_max_area=4096,\n                instance_score_thr=0.25\n            )\n        )\n    )\n)\n\ncustom_imports = dict(\n    imports=[\n        
'knet.det.kernel_head',\n        'knet.det.kernel_iter_head',\n        'knet.det.kernel_update_head',\n        'knet.det.semantic_fpn_wrapper',\n        'knet.det.dice_loss',\n        'knet.det.mask_hungarian_assigner',\n        'knet.det.mask_pseudo_sampler',\n        'knet.kernel_updator',\n        'knet.cross_entropy_loss',\n        'swin.swin_transformer',\n        'swin.mix_transformer',\n        'swin.DetectRS',\n        'swin.swin_transformer_rfp',\n        'external.cityscapes_step',\n        'external.dataset.pipelines.transforms',\n        'external.dataset.pipelines.loading',\n    ],\n    allow_failed_imports=False\n)\n"
  },
  {
    "path": "configs/det/_base_/models/knet_kitti_step_s3_r50_fpn.py",
    "content": "num_stages = 3\nnum_proposals = 100\nconv_kernel_size = 1\n\nmodel = dict(\n    type='KNet',\n    cityscapes=False,\n    kitti_step=True,\n    num_thing_classes=2,\n    num_stuff_classes=17,\n    backbone=dict(\n        type='ResNet',\n        depth=50,\n        num_stages=4,\n        out_indices=(0, 1, 2, 3),\n        frozen_stages=1,\n        norm_cfg=dict(type='BN', requires_grad=True),\n        norm_eval=True,\n        style='pytorch',\n        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),\n    neck=dict(\n        type='FPN',\n        in_channels=[256, 512, 1024, 2048],\n        out_channels=256,\n        start_level=0,\n        add_extra_convs='on_input',\n        num_outs=4\n    ),\n    rpn_head=dict(\n        type='ConvKernelHead',\n        num_classes=19,\n        num_thing_classes=2,\n        num_stuff_classes=17,\n        cat_stuff_mask=True,\n        conv_kernel_size=conv_kernel_size,\n        feat_downsample_stride=2,\n        feat_refine_stride=1,\n        feat_refine=False,\n        use_binary=True,\n        num_loc_convs=1,\n        num_seg_convs=1,\n        conv_normal_init=True,\n        localization_fpn=dict(\n            type='SemanticFPNWrapper',\n            in_channels=256,\n            feat_channels=256,\n            out_channels=256,\n            start_level=0,\n            end_level=3,\n            upsample_times=2,\n            positional_encoding=dict(\n                type='SinePositionalEncoding', num_feats=128, normalize=True),\n            cat_coors=False,\n            cat_coors_level=3,\n            fuse_by_cat=False,\n            return_list=False,\n            num_aux_convs=1,\n            norm_cfg=dict(type='GN', num_groups=32, requires_grad=True)),\n        num_proposals=num_proposals,\n        proposal_feats_with_obj=True,\n        xavier_init_kernel=False,\n        kernel_init_std=1,\n        num_cls_fcs=1,\n        in_channels=256,\n        feat_transform_cfg=None,\n        
loss_rank=dict(\n            type='CrossEntropyLoss',\n            use_sigmoid=False,\n            loss_weight=0.1),\n        loss_seg=dict(\n            type='FocalLoss',\n            use_sigmoid=True,\n            gamma=2.0,\n            alpha=0.25,\n            loss_weight=1.0),\n        loss_mask=dict(\n            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),\n        loss_dice=dict(type='DiceLoss', loss_weight=4.0)),\n    roi_head=dict(\n        type='KernelIterHead',\n        num_thing_classes=2,\n        num_stuff_classes=17,\n        do_panoptic=True,\n        num_stages=num_stages,\n        stage_loss_weights=[1] * num_stages,\n        proposal_feature_channel=256,\n        mask_head=[\n            dict(\n                type='KernelUpdateHead',\n                num_classes=19,\n                num_thing_classes=2,\n                num_stuff_classes=17,\n                num_ffn_fcs=2,\n                num_heads=8,\n                num_cls_fcs=1,\n                num_mask_fcs=1,\n                feedforward_channels=2048,\n                in_channels=256,\n                out_channels=256,\n                dropout=0.0,\n                mask_thr=0.5,\n                conv_kernel_size=conv_kernel_size,\n                mask_upsample_stride=2,\n                ffn_act_cfg=dict(type='ReLU', inplace=True),\n                with_ffn=True,\n                feat_transform_cfg=dict(\n                    conv_cfg=dict(type='Conv2d'),\n                    act_cfg=None\n                ),\n                kernel_updator_cfg=dict(\n                    type='KernelUpdator',\n                    in_channels=256,\n                    feat_channels=256,\n                    out_channels=256,\n                    input_feat_shape=3,\n                    act_cfg=dict(type='ReLU', inplace=True),\n                    norm_cfg=dict(type='LN')),\n                loss_rank=dict(\n                    type='CrossEntropyLoss',\n                    use_sigmoid=False,\n 
                   loss_weight=0.1\n                ),\n                loss_mask=dict(\n                    type='CrossEntropyLoss',\n                    use_sigmoid=True,\n                    loss_weight=1.0\n                ),\n                loss_dice=dict(\n                    type='DiceLoss', loss_weight=4.0\n                ),\n                loss_cls=dict(\n                    type='FocalLoss',\n                    use_sigmoid=True,\n                    gamma=2.0,\n                    alpha=0.25,\n                    loss_weight=2.0))\n            for _ in range(num_stages)\n        ]\n    ),\n    # training and testing settings\n    train_cfg=dict(\n        rpn=dict(\n            assigner=dict(\n                type='MaskHungarianAssigner',\n                cls_cost=dict(type='FocalLossCost', weight=2.0),\n                dice_cost=dict(type='DiceCost', weight=4.0, pred_act=True),\n                mask_cost=dict(type='MaskCost', weight=1.0, pred_act=True)),\n            sampler=dict(type='MaskPseudoSampler'),\n            pos_weight=1),\n        rcnn=[\n            dict(\n                assigner=dict(\n                    type='MaskHungarianAssigner',\n                    cls_cost=dict(type='FocalLossCost', weight=2.0),\n                    dice_cost=dict(type='DiceCost', weight=4.0, pred_act=True),\n                    mask_cost=dict(type='MaskCost', weight=1.0,\n                                   pred_act=True)),\n                sampler=dict(type='MaskPseudoSampler'),\n                pos_weight=1)\n\n            for _ in range(num_stages)\n        ]),\n    test_cfg=dict(\n        rpn=None,\n        rcnn=dict(\n            max_per_img=num_proposals,\n            mask_thr=0.5,\n            stuff_score_thr=0.05,\n            merge_stuff_thing=dict(\n                overlap_thr=0.6,\n                iou_thr=0.5,\n                stuff_max_area=4096,\n                instance_score_thr=0.25\n            )\n        )\n    )\n)\n\ncustom_imports = dict(\n  
  imports=[\n        'knet.det.kernel_head',\n        'knet.det.kernel_iter_head',\n        'knet.det.kernel_update_head',\n        'knet.det.semantic_fpn_wrapper',\n        'knet.det.dice_loss',\n        'knet.det.mask_hungarian_assigner',\n        'knet.det.mask_pseudo_sampler',\n        'knet.kernel_updator',\n        'knet.cross_entropy_loss',\n        'swin.swin_transformer',\n        'swin.mix_transformer',\n        'swin.DetectRS',\n        'swin.swin_transformer_rfp',\n        'external.cityscapes_step',\n        'external.kitti_step_dvps',\n        'external.dataset.dvps_pipelines.transforms',\n        'external.dataset.dvps_pipelines.loading',\n        'external.dataset.dvps_pipelines.tricks',\n        'external.dataset.pipelines.formatting',\n        # 'knet.video.knet_track',\n        # 'knet.video.knet_track_head',\n        'knet.video.track_heads',\n        'knet.video.kernel_head',\n        'knet.video.kernel_iter_head',\n        'knet.video.kernel_update_head',\n        'knet.video.knet_uni_track',\n        'knet.video.knet_quansi_dense',\n        # 'knet.video.knet_quansi_dense_roi',\n        'knet.video.knet_quansi_dense_roi_gt_box',\n        'knet.video.knet_quansi_dense_embed_fc',\n        'knet.video.knet_quansi_dense_embed_fc_joint_train',\n        # 'knet.video.knet_quansi_dense_embed_fc_with_appearance',\n        'knet.video.knet_quansi_dense_roi_gt_box_joint_train',\n        # 'knet.video.knet_quansi_dense_embed_fc_toy_exp',\n        'knet.video.qdtrack.losses.l2_loss',\n        'knet.video.qdtrack.losses.multipos_cross_entropy_loss',\n        'knet.video.qdtrack.trackers.quasi_dense_embed_tracker',\n    ],\n    allow_failed_imports=False\n)\n"
  },
  {
    "path": "configs/det/_base_/models/knet_s3_r50_deformable_fpn.py",
    "content": "num_stages = 3\nnum_proposals = 100\nconv_kernel_size = 1\nmodel = dict(\n    type='KNet',\n    backbone=dict(\n        type='ResNet',\n        depth=50,\n        num_stages=4,\n        out_indices=(0, 1, 2, 3),\n        frozen_stages=1,\n        norm_cfg=dict(type='BN', requires_grad=True),\n        norm_eval=True,\n        style='pytorch',\n        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),\n    neck=dict(\n        type='MSDeformAttnPixelDecoder',\n        num_outs=3,\n        norm_cfg=dict(type='GN', num_groups=32),\n        act_cfg=dict(type='ReLU'),\n        return_one_list=True,\n        encoder=dict(\n            type='DetrTransformerEncoder',\n            num_layers=6,\n            transformerlayers=dict(\n                type='BaseTransformerLayer',\n                attn_cfgs=dict(\n                    type='MultiScaleDeformableAttention',\n                    embed_dims=256,\n                    num_heads=8,\n                    num_levels=3,\n                    num_points=4,\n                    im2col_step=64,\n                    dropout=0.0,\n                    batch_first=False,\n                    norm_cfg=None,\n                    init_cfg=None),\n                ffn_cfgs=dict(\n                    type='FFN',\n                    embed_dims=256,\n                    feedforward_channels=1024,\n                    num_fcs=2,\n                    ffn_drop=0.0,\n                    act_cfg=dict(type='ReLU', inplace=True)),\n                operation_order=('self_attn', 'norm', 'ffn', 'norm')),\n            init_cfg=None),\n        positional_encoding=dict(\n            type='SinePositionalEncoding', num_feats=128, normalize=True),\n        init_cfg=None),\n    rpn_head=dict(\n        type='ConvKernelHead',\n        conv_kernel_size=conv_kernel_size,\n        feat_downsample_stride=2,\n        feat_refine_stride=1,\n        feat_refine=False,\n        use_binary=True,\n        num_loc_convs=1,\n        
num_seg_convs=1,\n        conv_normal_init=True,\n        localization_fpn=dict(\n            type='SemanticFPNWrapper',\n            in_channels=256,\n            feat_channels=256,\n            out_channels=256,\n            start_level=0,\n            end_level=3,\n            upsample_times=2,\n            positional_encoding=dict(\n                type='SinePositionalEncoding', num_feats=128, normalize=True),\n            cat_coors=False,\n            cat_coors_level=3,\n            fuse_by_cat=False,\n            return_list=False,\n            num_aux_convs=1,\n            norm_cfg=dict(type='GN', num_groups=32, requires_grad=True)),\n        num_proposals=num_proposals,\n        proposal_feats_with_obj=True,\n        xavier_init_kernel=False,\n        kernel_init_std=1,\n        num_cls_fcs=1,\n        in_channels=256,\n        num_classes=80,\n        feat_transform_cfg=None,\n        loss_seg=dict(\n            type='FocalLoss',\n            use_sigmoid=True,\n            gamma=2.0,\n            alpha=0.25,\n            loss_weight=1.0),\n        loss_mask=dict(\n            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),\n        loss_dice=dict(type='DiceLoss', loss_weight=4.0)),\n    roi_head=dict(\n        type='KernelIterHead',\n        num_stages=num_stages,\n        stage_loss_weights=[1] * num_stages,\n        proposal_feature_channel=256,\n        mask_head=[\n            dict(\n                type='KernelUpdateHead',\n                num_classes=80,\n                num_ffn_fcs=2,\n                num_heads=8,\n                num_cls_fcs=1,\n                num_mask_fcs=1,\n                feedforward_channels=2048,\n                in_channels=256,\n                out_channels=256,\n                dropout=0.0,\n                mask_thr=0.5,\n                conv_kernel_size=conv_kernel_size,\n                mask_upsample_stride=2,\n                ffn_act_cfg=dict(type='ReLU', inplace=True),\n                with_ffn=True,\n    
            feat_transform_cfg=dict(\n                    conv_cfg=dict(type='Conv2d'), act_cfg=None),\n                kernel_updator_cfg=dict(\n                    type='KernelUpdator',\n                    in_channels=256,\n                    feat_channels=256,\n                    out_channels=256,\n                    input_feat_shape=3,\n                    act_cfg=dict(type='ReLU', inplace=True),\n                    norm_cfg=dict(type='LN')),\n                loss_mask=dict(\n                    type='CrossEntropyLoss',\n                    use_sigmoid=True,\n                    loss_weight=1.0),\n                loss_dice=dict(\n                    type='DiceLoss', loss_weight=4.0),\n                loss_cls=dict(\n                    type='FocalLoss',\n                    use_sigmoid=True,\n                    gamma=2.0,\n                    alpha=0.25,\n                    loss_weight=2.0)) for _ in range(num_stages)\n        ]),\n    # training and testing settings\n    train_cfg=dict(\n        rpn=dict(\n            assigner=dict(\n                type='MaskHungarianAssigner',\n                cls_cost=dict(type='FocalLossCost', weight=2.0),\n                dice_cost=dict(type='DiceCost', weight=4.0, pred_act=True),\n                mask_cost=dict(type='MaskCost', weight=1.0, pred_act=True)),\n            sampler=dict(type='MaskPseudoSampler'),\n            pos_weight=1),\n        rcnn=[\n            dict(\n                assigner=dict(\n                    type='MaskHungarianAssigner',\n                    cls_cost=dict(type='FocalLossCost', weight=2.0),\n                    dice_cost=dict(type='DiceCost', weight=4.0, pred_act=True),\n                    mask_cost=dict(type='MaskCost', weight=1.0,\n                                   pred_act=True)),\n                sampler=dict(type='MaskPseudoSampler'),\n                pos_weight=1) for _ in range(num_stages)\n        ]),\n    test_cfg=dict(\n        rpn=None,\n        rcnn=dict(\n            
max_per_img=num_proposals,\n            mask_thr=0.5,\n            merge_stuff_thing=dict(\n                iou_thr=0.5, stuff_max_area=4096, instance_score_thr=0.3))))\n\ncustom_imports = dict(\n    imports=[\n        'knet.det.kernel_head',\n        'knet.det.kernel_iter_head',\n        'knet.det.kernel_update_head',\n        'knet.det.semantic_fpn_wrapper',\n        'knet.det.dice_loss',\n        'knet.kernel_updator',\n        'knet.det.msdeformattn_decoder',\n        'knet.det.mask_hungarian_assigner',\n        'knet.det.mask_pseudo_sampler',\n        'external.coco_panoptic',\n        'swin.swin_transformer'\n    ],\n    allow_failed_imports=False)\n"
  },
  {
    "path": "configs/det/_base_/models/knet_s3_r50_fpn.py",
    "content": "num_stages = 3\nnum_proposals = 100\nconv_kernel_size = 1\nmodel = dict(\n    type='KNet',\n    backbone=dict(\n        type='ResNet',\n        depth=50,\n        num_stages=4,\n        out_indices=(0, 1, 2, 3),\n        frozen_stages=1,\n        norm_cfg=dict(type='BN', requires_grad=True),\n        norm_eval=True,\n        style='pytorch',\n        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),\n    neck=dict(\n        type='FPN',\n        in_channels=[256, 512, 1024, 2048],\n        out_channels=256,\n        start_level=0,\n        add_extra_convs='on_input',\n        num_outs=4),\n    rpn_head=dict(\n        type='ConvKernelHead',\n        conv_kernel_size=conv_kernel_size,\n        feat_downsample_stride=2,\n        feat_refine_stride=1,\n        feat_refine=False,\n        use_binary=True,\n        num_loc_convs=1,\n        num_seg_convs=1,\n        conv_normal_init=True,\n        localization_fpn=dict(\n            type='SemanticFPNWrapper',\n            in_channels=256,\n            feat_channels=256,\n            out_channels=256,\n            start_level=0,\n            end_level=3,\n            upsample_times=2,\n            positional_encoding=dict(\n                type='SinePositionalEncoding', num_feats=128, normalize=True),\n            cat_coors=False,\n            cat_coors_level=3,\n            fuse_by_cat=False,\n            return_list=False,\n            num_aux_convs=1,\n            norm_cfg=dict(type='GN', num_groups=32, requires_grad=True)),\n        num_proposals=num_proposals,\n        proposal_feats_with_obj=True,\n        xavier_init_kernel=False,\n        kernel_init_std=1,\n        num_cls_fcs=1,\n        in_channels=256,\n        num_classes=80,\n        feat_transform_cfg=None,\n        loss_seg=dict(\n            type='FocalLoss',\n            use_sigmoid=True,\n            gamma=2.0,\n            alpha=0.25,\n            loss_weight=1.0),\n        loss_mask=dict(\n            
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),\n        loss_dice=dict(type='DiceLoss', loss_weight=4.0)),\n    roi_head=dict(\n        type='KernelIterHead',\n        num_stages=num_stages,\n        stage_loss_weights=[1] * num_stages,\n        proposal_feature_channel=256,\n        mask_head=[\n            dict(\n                type='KernelUpdateHead',\n                num_classes=80,\n                num_ffn_fcs=2,\n                num_heads=8,\n                num_cls_fcs=1,\n                num_mask_fcs=1,\n                feedforward_channels=2048,\n                in_channels=256,\n                out_channels=256,\n                dropout=0.0,\n                mask_thr=0.5,\n                conv_kernel_size=conv_kernel_size,\n                mask_upsample_stride=2,\n                ffn_act_cfg=dict(type='ReLU', inplace=True),\n                with_ffn=True,\n                feat_transform_cfg=dict(\n                    conv_cfg=dict(type='Conv2d'), act_cfg=None),\n                kernel_updator_cfg=dict(\n                    type='KernelUpdator',\n                    in_channels=256,\n                    feat_channels=256,\n                    out_channels=256,\n                    input_feat_shape=3,\n                    act_cfg=dict(type='ReLU', inplace=True),\n                    norm_cfg=dict(type='LN')),\n                loss_mask=dict(\n                    type='CrossEntropyLoss',\n                    use_sigmoid=True,\n                    loss_weight=1.0),\n                loss_dice=dict(\n                    type='DiceLoss', loss_weight=4.0),\n                loss_cls=dict(\n                    type='FocalLoss',\n                    use_sigmoid=True,\n                    gamma=2.0,\n                    alpha=0.25,\n                    loss_weight=2.0)) for _ in range(num_stages)\n        ]),\n    # training and testing settings\n    train_cfg=dict(\n        rpn=dict(\n            assigner=dict(\n                
type='MaskHungarianAssigner',\n                cls_cost=dict(type='FocalLossCost', weight=2.0),\n                dice_cost=dict(type='DiceCost', weight=4.0, pred_act=True),\n                mask_cost=dict(type='MaskCost', weight=1.0, pred_act=True)),\n            sampler=dict(type='MaskPseudoSampler'),\n            pos_weight=1),\n        rcnn=[\n            dict(\n                assigner=dict(\n                    type='MaskHungarianAssigner',\n                    cls_cost=dict(type='FocalLossCost', weight=2.0),\n                    dice_cost=dict(type='DiceCost', weight=4.0, pred_act=True),\n                    mask_cost=dict(type='MaskCost', weight=1.0,\n                                   pred_act=True)),\n                sampler=dict(type='MaskPseudoSampler'),\n                pos_weight=1) for _ in range(num_stages)\n        ]),\n    test_cfg=dict(\n        rpn=None,\n        rcnn=dict(\n            max_per_img=num_proposals,\n            mask_thr=0.5,\n            merge_stuff_thing=dict(\n                iou_thr=0.5, stuff_max_area=4096, instance_score_thr=0.3))))\n\ncustom_imports = dict(\n    imports=[\n        'knet.det.knet',\n        'knet.det.kernel_head',\n        'knet.det.kernel_iter_head',\n        'knet.det.kernel_update_head',\n        'knet.det.semantic_fpn_wrapper',\n        'knet.det.dice_loss',\n        'knet.kernel_updator',\n        'knet.det.msdeformattn_decoder',\n        'knet.det.mask_hungarian_assigner',\n        'knet.det.mask_pseudo_sampler',\n        'panoptic_fpn.coco_panoptic',\n    ],\n    allow_failed_imports=False)\n"
  },
  {
    "path": "configs/det/_base_/models/knet_s3_r50_fpn_panoptic.py",
    "content": "num_stages = 3\nnum_proposals = 100\nconv_kernel_size = 1\nmodel = dict(\n    type='KNet',\n    backbone=dict(\n        type='ResNet',\n        depth=50,\n        num_stages=4,\n        out_indices=(0, 1, 2, 3),\n        frozen_stages=1,\n        norm_cfg=dict(type='BN', requires_grad=True),\n        norm_eval=True,\n        style='pytorch',\n        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),\n    neck=dict(\n        type='FPN',\n        in_channels=[256, 512, 1024, 2048],\n        out_channels=256,\n        start_level=0,\n        add_extra_convs='on_input',\n        num_outs=4),\n    rpn_head=dict(\n        type='ConvKernelHead',\n        num_classes=133,  # modified for panoptic\n        cat_stuff_mask=True,  # modified for panoptic\n        conv_kernel_size=conv_kernel_size,\n        feat_downsample_stride=2,\n        feat_refine_stride=1,\n        feat_refine=False,\n        use_binary=True,\n        num_loc_convs=1,\n        num_seg_convs=1,\n        conv_normal_init=True,\n        localization_fpn=dict(\n            type='SemanticFPNWrapper',\n            in_channels=256,\n            feat_channels=256,\n            out_channels=256,\n            start_level=0,\n            end_level=3,\n            upsample_times=2,\n            positional_encoding=dict(\n                type='SinePositionalEncoding', num_feats=128, normalize=True),\n            cat_coors=False,\n            cat_coors_level=3,\n            fuse_by_cat=False,\n            return_list=False,\n            num_aux_convs=1,\n            norm_cfg=dict(type='GN', num_groups=32, requires_grad=True)),\n        num_proposals=num_proposals,\n        proposal_feats_with_obj=True,\n        xavier_init_kernel=False,\n        kernel_init_std=1,\n        num_cls_fcs=1,\n        in_channels=256,\n        feat_transform_cfg=None,\n        loss_rank=dict(\n            type='CrossEntropyLoss',\n            use_sigmoid=False,\n            loss_weight=0.1),\n        
loss_seg=dict(\n            type='FocalLoss',\n            use_sigmoid=True,\n            gamma=2.0,\n            alpha=0.25,\n            loss_weight=1.0),\n        loss_mask=dict(\n            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),\n        loss_dice=dict(type='DiceLoss', loss_weight=4.0)),\n    roi_head=dict(\n        type='KernelIterHead',\n        do_panoptic=True,\n        num_stages=num_stages,\n        stage_loss_weights=[1] * num_stages,\n        proposal_feature_channel=256,\n        mask_head=[\n            dict(\n                type='KernelUpdateHead',\n                num_classes=133,\n                num_ffn_fcs=2,\n                num_heads=8,\n                num_cls_fcs=1,\n                num_mask_fcs=1,\n                feedforward_channels=2048,\n                in_channels=256,\n                out_channels=256,\n                dropout=0.0,\n                mask_thr=0.5,\n                conv_kernel_size=conv_kernel_size,\n                mask_upsample_stride=2,\n                ffn_act_cfg=dict(type='ReLU', inplace=True),\n                with_ffn=True,\n                feat_transform_cfg=dict(\n                    conv_cfg=dict(type='Conv2d'), act_cfg=None),\n                kernel_updator_cfg=dict(\n                    type='KernelUpdator',\n                    in_channels=256,\n                    feat_channels=256,\n                    out_channels=256,\n                    input_feat_shape=3,\n                    act_cfg=dict(type='ReLU', inplace=True),\n                    norm_cfg=dict(type='LN')),\n                loss_rank=dict(\n                    type='CrossEntropyLoss',\n                    use_sigmoid=False,\n                    loss_weight=0.1),\n                loss_mask=dict(\n                    type='CrossEntropyLoss',\n                    use_sigmoid=True,\n                    loss_weight=1.0),\n                loss_dice=dict(\n                    type='DiceLoss', loss_weight=4.0),\n                
loss_cls=dict(\n                    type='FocalLoss',\n                    use_sigmoid=True,\n                    gamma=2.0,\n                    alpha=0.25,\n                    loss_weight=2.0)) for _ in range(num_stages)\n        ]),\n    # training and testing settings\n    train_cfg=dict(\n        rpn=dict(\n            assigner=dict(\n                type='MaskHungarianAssigner',\n                cls_cost=dict(type='FocalLossCost', weight=2.0),\n                dice_cost=dict(type='DiceCost', weight=4.0, pred_act=True),\n                mask_cost=dict(type='MaskCost', weight=1.0, pred_act=True)),\n            sampler=dict(type='MaskPseudoSampler'),\n            pos_weight=1),\n        rcnn=[\n            dict(\n                assigner=dict(\n                    type='MaskHungarianAssigner',\n                    cls_cost=dict(type='FocalLossCost', weight=2.0),\n                    dice_cost=dict(type='DiceCost', weight=4.0, pred_act=True),\n                    mask_cost=dict(type='MaskCost', weight=1.0,\n                                   pred_act=True)),\n                sampler=dict(type='MaskPseudoSampler'),\n                pos_weight=1) for _ in range(num_stages)\n        ]),\n    test_cfg=dict(\n        rpn=None,\n        rcnn=dict(\n            max_per_img=num_proposals,\n            mask_thr=0.5,\n            stuff_score_thr=0.05,\n            merge_stuff_thing=dict(\n                overlap_thr=0.6,\n                iou_thr=0.5, stuff_max_area=4096, instance_score_thr=0.3))))\n\ncustom_imports = dict(\n    imports=[\n        'knet.det.kernel_head',\n        'knet.det.kernel_iter_head',\n        'knet.det.kernel_update_head',\n        'knet.det.semantic_fpn_wrapper',\n        'knet.det.dice_loss',\n        'knet.kernel_updator',\n        'knet.cross_entropy_loss',\n        'knet.det.mask_hungarian_assigner',\n        'knet.det.mask_pseudo_sampler',\n        'swin.swin_transformer',\n        'external.mot_step',\n        'swin.mix_transformer',\n       
 'swin.DetectRS',\n        'swin.swin_transformer_rfp',\n        'external.coco_panoptic',\n        'external.mapillary_panoptic',\n        'external.cityscape_panoptic',\n        'external.kitti_step_dvps',\n        'external.dataset.dvps_pipelines.transforms',\n        'external.dataset.dvps_pipelines.loading',\n        'external.dataset.dvps_pipelines.tricks',\n        'external.dataset.pipelines.formatting',\n    ],\n    allow_failed_imports=False)\n"
  },
  {
    "path": "configs/det/_base_/models/knet_vipseg_s3_r50_fpn.py",
    "content": "num_stages = 3\nnum_proposals = 100\nconv_kernel_size = 1\n\nnum_thing_classes = 58\nnum_stuff_classes = 66\nnum_classes = num_stuff_classes + num_thing_classes\n\nmodel = dict(\n    type='KNet',\n    cityscapes=False,\n    kitti_step=True,\n    num_thing_classes=num_thing_classes,\n    num_stuff_classes=num_stuff_classes,\n    backbone=dict(\n        type='ResNet',\n        depth=50,\n        num_stages=4,\n        out_indices=(0, 1, 2, 3),\n        frozen_stages=1,\n        norm_cfg=dict(type='BN', requires_grad=True),\n        norm_eval=True,\n        style='pytorch',\n        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),\n    neck=dict(\n        type='FPN',\n        in_channels=[256, 512, 1024, 2048],\n        out_channels=256,\n        start_level=0,\n        add_extra_convs='on_input',\n        num_outs=4\n    ),\n    rpn_head=dict(\n        type='ConvKernelHead',\n        num_classes=num_classes,\n        num_thing_classes=num_thing_classes,\n        num_stuff_classes=num_stuff_classes,\n        cat_stuff_mask=True,\n        conv_kernel_size=conv_kernel_size,\n        feat_downsample_stride=2,\n        feat_refine_stride=1,\n        feat_refine=False,\n        use_binary=True,\n        num_loc_convs=1,\n        num_seg_convs=1,\n        conv_normal_init=True,\n        localization_fpn=dict(\n            type='SemanticFPNWrapper',\n            in_channels=256,\n            feat_channels=256,\n            out_channels=256,\n            start_level=0,\n            end_level=3,\n            upsample_times=2,\n            positional_encoding=dict(\n                type='SinePositionalEncoding', num_feats=128, normalize=True),\n            cat_coors=False,\n            cat_coors_level=3,\n            fuse_by_cat=False,\n            return_list=False,\n            num_aux_convs=1,\n            norm_cfg=dict(type='GN', num_groups=32, requires_grad=True)),\n        num_proposals=num_proposals,\n        
proposal_feats_with_obj=True,\n        xavier_init_kernel=False,\n        kernel_init_std=1,\n        num_cls_fcs=1,\n        in_channels=256,\n        feat_transform_cfg=None,\n        loss_rank=dict(\n            type='CrossEntropyLoss',\n            use_sigmoid=False,\n            loss_weight=0.1),\n        loss_seg=dict(\n            type='FocalLoss',\n            use_sigmoid=True,\n            gamma=2.0,\n            alpha=0.25,\n            loss_weight=1.0),\n        loss_mask=dict(\n            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),\n        loss_dice=dict(type='DiceLoss', loss_weight=4.0)),\n    roi_head=dict(\n        type='KernelIterHead',\n        num_thing_classes=num_thing_classes,\n        num_stuff_classes=num_stuff_classes,\n        do_panoptic=True,\n        num_stages=num_stages,\n        stage_loss_weights=[1] * num_stages,\n        proposal_feature_channel=256,\n        mask_head=[\n            dict(\n                type='KernelUpdateHead',\n                num_classes=num_classes,\n                num_thing_classes=num_thing_classes,\n                num_stuff_classes=num_stuff_classes,\n                num_ffn_fcs=2,\n                num_heads=8,\n                num_cls_fcs=1,\n                num_mask_fcs=1,\n                feedforward_channels=2048,\n                in_channels=256,\n                out_channels=256,\n                dropout=0.0,\n                mask_thr=0.5,\n                conv_kernel_size=conv_kernel_size,\n                mask_upsample_stride=2,\n                ffn_act_cfg=dict(type='ReLU', inplace=True),\n                with_ffn=True,\n                feat_transform_cfg=dict(\n                    conv_cfg=dict(type='Conv2d'),\n                    act_cfg=None\n                ),\n                kernel_updator_cfg=dict(\n                    type='KernelUpdator',\n                    in_channels=256,\n                    feat_channels=256,\n                    out_channels=256,\n              
      input_feat_shape=3,\n                    act_cfg=dict(type='ReLU', inplace=True),\n                    norm_cfg=dict(type='LN')),\n                loss_rank=dict(\n                    type='CrossEntropyLoss',\n                    use_sigmoid=False,\n                    loss_weight=0.1\n                ),\n                loss_mask=dict(\n                    type='CrossEntropyLoss',\n                    use_sigmoid=True,\n                    loss_weight=1.0\n                ),\n                loss_dice=dict(\n                    type='DiceLoss', loss_weight=4.0\n                ),\n                loss_cls=dict(\n                    type='FocalLoss',\n                    use_sigmoid=True,\n                    gamma=2.0,\n                    alpha=0.25,\n                    loss_weight=2.0))\n            for _ in range(num_stages)\n        ]\n    ),\n    # training and testing settings\n    train_cfg=dict(\n        rpn=dict(\n            assigner=dict(\n                type='MaskHungarianAssigner',\n                cls_cost=dict(type='FocalLossCost', weight=2.0),\n                dice_cost=dict(type='DiceCost', weight=4.0, pred_act=True),\n                mask_cost=dict(type='MaskCost', weight=1.0, pred_act=True)),\n            sampler=dict(type='MaskPseudoSampler'),\n            pos_weight=1),\n        rcnn=[\n            dict(\n                assigner=dict(\n                    type='MaskHungarianAssigner',\n                    cls_cost=dict(type='FocalLossCost', weight=2.0),\n                    dice_cost=dict(type='DiceCost', weight=4.0, pred_act=True),\n                    mask_cost=dict(type='MaskCost', weight=1.0,\n                                   pred_act=True)),\n                sampler=dict(type='MaskPseudoSampler'),\n                pos_weight=1)\n\n            for _ in range(num_stages)\n        ]),\n    test_cfg=dict(\n        rpn=None,\n        rcnn=dict(\n            max_per_img=num_proposals,\n            mask_thr=0.5,\n            
stuff_score_thr=0.05,\n            merge_stuff_thing=dict(\n                overlap_thr=0.6,\n                iou_thr=0.5,\n                stuff_max_area=4096,\n                instance_score_thr=0.25\n            )\n        )\n    )\n)\n\ncustom_imports = dict(\n    imports=[\n        'knet.det.knet',\n        'knet.det.kernel_head',\n        'knet.det.kernel_iter_head',\n        'knet.det.kernel_update_head',\n        'knet.det.semantic_fpn_wrapper',\n        'knet.det.dice_loss',\n        'knet.det.mask_hungarian_assigner',\n        'knet.det.mask_pseudo_sampler',\n        'knet.kernel_updator',\n        'knet.cross_entropy_loss',\n        'swin.swin_transformer',\n        'swin.mix_transformer',\n        'swin.DetectRS',\n        'swin.swin_transformer_rfp',\n        'external.cityscapes_step',\n        'external.kitti_step_dvps',\n        'external.vipseg_dvps',\n        'external.dataset.dvps_pipelines.transforms',\n        'external.dataset.dvps_pipelines.loading',\n        'external.dataset.dvps_pipelines.tricks',\n        'external.dataset.pipelines.formatting',\n        'external.dataset.pipelines.transforms',\n        'knet.video.knet',\n        'knet.video.knet_quansi_dense',\n        'knet.video.knet_quansi_dense_roi_gt_box',\n        # 'knet.video.knet_track',\n        # 'knet.video.knet_track_head',\n        'knet.video.track_heads',\n        'knet.video.kernel_head',\n        'knet.video.kernel_iter_head',\n        'knet.video.kernel_update_head',\n        'knet.video.knet_uni_track',\n        'knet.video.knet_quansi_dense_embed_fc',\n        'knet.video.knet_quansi_dense_embed_fc_joint_train',\n        'knet.video.qdtrack.losses.l2_loss',\n        'knet.video.qdtrack.losses.multipos_cross_entropy_loss',\n        'knet.video.qdtrack.trackers.quasi_dense_embed_tracker',\n\n    ],\n    allow_failed_imports=False\n)\n"
  },
  {
    "path": "configs/det/_base_/models/video_knet_s3_r50_fpn_panoptic.py",
    "content": "num_stages = 3\nnum_proposals = 100\nconv_kernel_size = 1\nmodel = dict(\n    type='VideoKNet',\n    backbone=dict(\n        type='ResNet',\n        depth=50,\n        num_stages=4,\n        out_indices=(0, 1, 2, 3),\n        frozen_stages=1,\n        norm_cfg=dict(type='BN', requires_grad=True),\n        norm_eval=True,\n        style='pytorch',\n        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),\n    neck=dict(\n        type='FPN',\n        in_channels=[256, 512, 1024, 2048],\n        out_channels=256,\n        start_level=0,\n        add_extra_convs='on_input',\n        num_outs=4),\n    rpn_head=dict(\n        type='VideoConvKernelHead',\n        num_classes=133,  # modified for panoptic\n        cat_stuff_mask=True,  # modified for panoptic\n        conv_kernel_size=conv_kernel_size,\n        feat_downsample_stride=2,\n        feat_refine_stride=1,\n        feat_refine=False,\n        use_binary=True,\n        num_loc_convs=1,\n        num_seg_convs=1,\n        conv_normal_init=True,\n        localization_fpn=dict(\n            type='SemanticFPNWrapper',\n            in_channels=256,\n            feat_channels=256,\n            out_channels=256,\n            start_level=0,\n            end_level=3,\n            upsample_times=2,\n            positional_encoding=dict(\n                type='SinePositionalEncoding', num_feats=128, normalize=True),\n            cat_coors=False,\n            cat_coors_level=3,\n            fuse_by_cat=False,\n            return_list=False,\n            num_aux_convs=1,\n            norm_cfg=dict(type='GN', num_groups=32, requires_grad=True)),\n        num_proposals=num_proposals,\n        proposal_feats_with_obj=True,\n        xavier_init_kernel=False,\n        kernel_init_std=1,\n        num_cls_fcs=1,\n        in_channels=256,\n        feat_transform_cfg=None,\n        loss_rank=dict(\n            type='CrossEntropyLoss',\n            use_sigmoid=False,\n            
loss_weight=0.1),\n        loss_seg=dict(\n            type='FocalLoss',\n            use_sigmoid=True,\n            gamma=2.0,\n            alpha=0.25,\n            loss_weight=1.0),\n        loss_mask=dict(\n            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),\n        loss_dice=dict(type='DiceLoss', loss_weight=4.0)),\n    roi_head=dict(\n        type='VideoKernelIterHead',\n        do_panoptic=True,\n        num_stages=num_stages,\n        stage_loss_weights=[1] * num_stages,\n        proposal_feature_channel=256,\n        mask_head=[\n            dict(\n                type='VideoKernelUpdateHead',\n                num_classes=133,\n                num_ffn_fcs=2,\n                num_heads=8,\n                num_cls_fcs=1,\n                num_mask_fcs=1,\n                feedforward_channels=2048,\n                in_channels=256,\n                out_channels=256,\n                dropout=0.0,\n                mask_thr=0.5,\n                conv_kernel_size=conv_kernel_size,\n                mask_upsample_stride=2,\n                ffn_act_cfg=dict(type='ReLU', inplace=True),\n                with_ffn=True,\n                feat_transform_cfg=dict(\n                    conv_cfg=dict(type='Conv2d'), act_cfg=None),\n                kernel_updator_cfg=dict(\n                    type='KernelUpdator',\n                    in_channels=256,\n                    feat_channels=256,\n                    out_channels=256,\n                    input_feat_shape=3,\n                    act_cfg=dict(type='ReLU', inplace=True),\n                    norm_cfg=dict(type='LN')),\n                loss_rank=dict(\n                    type='CrossEntropyLoss',\n                    use_sigmoid=False,\n                    loss_weight=0.1),\n                loss_mask=dict(\n                    type='CrossEntropyLoss',\n                    use_sigmoid=True,\n                    loss_weight=1.0),\n                loss_dice=dict(\n                    type='DiceLoss', 
loss_weight=4.0),\n                loss_cls=dict(\n                    type='FocalLoss',\n                    use_sigmoid=True,\n                    gamma=2.0,\n                    alpha=0.25,\n                    loss_weight=2.0)) for _ in range(num_stages)\n        ]),\n    # training and testing settings\n    train_cfg=dict(\n        rpn=dict(\n            assigner=dict(\n                type='MaskHungarianAssigner',\n                cls_cost=dict(type='FocalLossCost', weight=2.0),\n                dice_cost=dict(type='DiceCost', weight=4.0, pred_act=True),\n                mask_cost=dict(type='MaskCost', weight=1.0, pred_act=True)),\n            sampler=dict(type='MaskPseudoSampler'),\n            pos_weight=1),\n        rcnn=[\n            dict(\n                assigner=dict(\n                    type='MaskHungarianAssigner',\n                    cls_cost=dict(type='FocalLossCost', weight=2.0),\n                    dice_cost=dict(type='DiceCost', weight=4.0, pred_act=True),\n                    mask_cost=dict(type='MaskCost', weight=1.0,\n                                   pred_act=True)),\n                sampler=dict(type='MaskPseudoSampler'),\n                pos_weight=1) for _ in range(num_stages)\n        ]),\n    test_cfg=dict(\n        rpn=None,\n        rcnn=dict(\n            max_per_img=num_proposals,\n            mask_thr=0.5,\n            stuff_score_thr=0.05,\n            merge_stuff_thing=dict(\n                overlap_thr=0.6,\n                iou_thr=0.5, stuff_max_area=4096, instance_score_thr=0.3))))\n\ncustom_imports = dict(\n    imports=[\n        'knet.det.kernel_head',\n        'knet.det.kernel_iter_head',\n        'knet.det.kernel_update_head',\n        'knet.det.semantic_fpn_wrapper',\n        'knet.det.dice_loss',\n        'knet.cross_entropy_loss',\n        'knet.kernel_updator',\n        'knet.det.mask_hungarian_assigner',\n        'knet.det.mask_pseudo_sampler',\n        'external.coco_panoptic',\n        
'external.youtubevis_clips',\n        'external.cityscapes_vps',\n        'external.cityscape_panoptic',\n        'external.cityscapes_dvps',\n        'swin.swin_transformer',\n        'swin.mix_transformer',\n        'swin.DetectRS',\n        'swin.swin_transformer_rfp',\n        # 'knet.video.knet_track',\n        # 'knet.video.knet_track_head',\n        'knet.video.track_heads',\n        'knet.video.kernel_head',\n        'knet.video.kernel_iter_head',\n        'knet.video.kernel_update_head',\n        'knet.video.knet_uni_track',\n        'knet.video.knet_quansi_dense',\n        'knet.video.knet_quansi_dense_conv_mask',\n        'knet.video.knet_quansi_dense_roi_gt_box',\n        'knet.video.knet_quansi_dense_embed_fc',\n        # 'knet.video.knet_quansi_dense_embed_fc_joint_train',\n        'knet.video.knet_quansi_dense_roi_gt_box_joint_train',\n        'knet.video.qdtrack.losses.l2_loss',\n        'knet.video.qdtrack.losses.multipos_cross_entropy_loss',\n        'knet.video.qdtrack.trackers.quasi_dense_embed_tracker',\n\n        'knet.video.knet_quansi_dense_embed_fc_toy_exp',\n        'external.ext.ytvos',\n        'external.ext.mask',\n\n        'external.dataset.pipelines.transforms',\n        'external.dataset.pipelines.loading',\n        'external.dataset.pipelines.formatting',\n\n        'external.dataset.dvps_pipelines.transforms',\n        'external.dataset.dvps_pipelines.loading',\n        'external.dataset.dvps_pipelines.tricks',\n    ],\n    allow_failed_imports=False)\n"
  },
  {
    "path": "configs/det/_base_/schedules/schedule_10e.py",
    "content": "# optimizer\n# this is different from the original 1x schedule that uses SGD\noptimizer = dict(\n    type='AdamW',\n    lr=0.0001,\n    weight_decay=0.05,\n    paramwise_cfg=dict(custom_keys={'backbone': dict(lr_mult=0.25)}))\noptimizer_config = dict(grad_clip=dict(max_norm=1, norm_type=2))\n# learning policy\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[8,])\nrunner = dict(type='EpochBasedRunner', max_epochs=10)\n"
  },
  {
    "path": "configs/det/_base_/schedules/schedule_1x.py",
    "content": "# optimizer\n# this is different from the original 1x schedule that uses SGD\noptimizer = dict(\n    type='AdamW',\n    lr=0.0001,\n    weight_decay=0.05,\n    paramwise_cfg=dict(custom_keys={'backbone': dict(lr_mult=0.25)}))\noptimizer_config = dict(grad_clip=dict(max_norm=1, norm_type=2))\n# learning policy\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[8, 11])\nrunner = dict(type='EpochBasedRunner', max_epochs=12)\n"
  },
  {
    "path": "configs/det/coco/knet_s3_r50_deformable_fpn_ms-3x_coco.py",
    "content": "_base_ = [\n    '../_base_/models/knet_s3_r50_deformable_fpn.py',\n    '../common/mstrain_3x_coco_instance.py'\n]\n\nmodel = dict(\n    backbone=dict(\n        type='ResNet',\n        depth=50,\n        num_stages=4,\n        out_indices=(0, 1, 2, 3),\n        frozen_stages=1,\n        norm_cfg=dict(type='SyncBN', requires_grad=True),\n        norm_eval=True,),\n\n)"
  },
  {
    "path": "configs/det/coco/knet_s3_r50_fpn_ms-3x_coco-panoptic.py",
    "content": "_base_ = [\n    '../_base_/models/knet_s3_r50_fpn_panoptic.py',\n    '../common/mstrain_3x_coco_panoptic.py'\n]\nnum_stages = 3\nnum_proposals = 100\nconv_kernel_size = 1\nmodel = dict(\n    type='KNet',\n    backbone=dict(\n        type='ResNet',\n        depth=50,\n        num_stages=4,\n        out_indices=(0, 1, 2, 3),\n        frozen_stages=1,\n        norm_cfg=dict(type='BN', requires_grad=True),\n        norm_eval=True,\n        style='pytorch',\n        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),\n    neck=dict(\n        type='FPN',\n        in_channels=[256, 512, 1024, 2048],\n        out_channels=256,\n        start_level=0,\n        add_extra_convs='on_input',\n        num_outs=4),\n    rpn_head=dict(\n        type='ConvKernelHead',\n        num_classes=133,  # modified for panoptic\n        cat_stuff_mask=True,  # modified for panoptic\n        conv_kernel_size=conv_kernel_size,\n        feat_downsample_stride=2,\n        feat_refine_stride=1,\n        feat_refine=False,\n        use_binary=True,\n        num_loc_convs=1,\n        num_seg_convs=1,\n        conv_normal_init=True,\n        localization_fpn=dict(\n            type='SemanticFPNWrapper',\n            in_channels=256,\n            feat_channels=256,\n            out_channels=256,\n            start_level=0,\n            end_level=3,\n            upsample_times=2,\n            positional_encoding=dict(\n                type='SinePositionalEncoding', num_feats=128, normalize=True),\n            cat_coors=False,\n            cat_coors_level=3,\n            fuse_by_cat=False,\n            return_list=False,\n            num_aux_convs=1,\n            norm_cfg=dict(type='GN', num_groups=32, requires_grad=True)),\n        num_proposals=num_proposals,\n        proposal_feats_with_obj=True,\n        xavier_init_kernel=False,\n        kernel_init_std=1,\n        num_cls_fcs=1,\n        in_channels=256,\n        feat_transform_cfg=None,\n        
loss_rank=dict(\n            type='CrossEntropyLoss',\n            use_sigmoid=False,\n            loss_weight=0.1),\n        loss_seg=dict(\n            type='FocalLoss',\n            use_sigmoid=True,\n            gamma=2.0,\n            alpha=0.25,\n            loss_weight=1.0),\n        loss_mask=dict(\n            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),\n        loss_dice=dict(type='DiceLoss', loss_weight=4.0)),\n    roi_head=dict(\n        type='KernelIterHead',\n        do_panoptic=True,\n        merge_joint=True,\n        num_stages=num_stages,\n        stage_loss_weights=[1] * num_stages,\n        proposal_feature_channel=256,\n        mask_head=[\n            dict(\n                type='KernelUpdateHead',\n                num_classes=133,\n                num_ffn_fcs=2,\n                num_heads=8,\n                num_cls_fcs=1,\n                num_mask_fcs=1,\n                feedforward_channels=2048,\n                in_channels=256,\n                out_channels=256,\n                dropout=0.0,\n                mask_thr=0.5,\n                conv_kernel_size=conv_kernel_size,\n                mask_upsample_stride=2,\n                ffn_act_cfg=dict(type='ReLU', inplace=True),\n                with_ffn=True,\n                feat_transform_cfg=dict(\n                    conv_cfg=dict(type='Conv2d'), act_cfg=None),\n                kernel_updator_cfg=dict(\n                    type='KernelUpdator',\n                    in_channels=256,\n                    feat_channels=256,\n                    out_channels=256,\n                    input_feat_shape=3,\n                    act_cfg=dict(type='ReLU', inplace=True),\n                    norm_cfg=dict(type='LN')),\n                loss_rank=dict(\n                    type='CrossEntropyLoss',\n                    use_sigmoid=False,\n                    loss_weight=0.1),\n                loss_mask=dict(\n                    type='CrossEntropyLoss',\n                    
use_sigmoid=True,\n                    loss_weight=1.0),\n                loss_dice=dict(\n                    type='DiceLoss', loss_weight=4.0),\n                loss_cls=dict(\n                    type='FocalLoss',\n                    use_sigmoid=True,\n                    gamma=2.0,\n                    alpha=0.25,\n                    loss_weight=2.0)) for _ in range(num_stages)\n        ]),\n    test_cfg=dict(\n        rpn=None,\n        rcnn=dict(\n            max_per_img=num_proposals,\n            mask_thr=0.5,\n            stuff_score_thr=0.05,\n            merge_stuff_thing=dict(\n                overlap_thr=0.6,\n                iou_thr=0.5, stuff_max_area=4096, instance_score_thr=0.3)))\n)"
  },
  {
    "path": "configs/det/coco/knet_s3_r50_fpn_ms-3x_coco.py",
    "content": "_base_ = [\n    '../_base_/models/knet_s3_r50_fpn.py',\n    '../common/mstrain_3x_coco_instance.py'\n]\n"
  },
  {
    "path": "configs/det/coco/knet_s3_swin-b_deformable_fpn_ms-3x_coco.py",
    "content": "_base_ = [\n    '../_base_/models/knet_s3_r50_deformable_fpn.py',\n    '../common/mstrain_3x_coco_instance.py'\n]\n\nmodel = dict(\n    pretrained='/mnt/lustre/lixiangtai/pretrained/swin/swin_base_patch4_window7_224_22k.pth',\n    backbone=dict(\n        _delete_=True,\n        type='SwinTransformerDIY',\n        embed_dims=128,\n        depths=[2, 2, 18, 2],\n        num_heads=[4, 8, 16, 32],\n        window_size=7,\n        mlp_ratio=4.,\n        qkv_bias=True,\n        qk_scale=None,\n        drop_rate=0.,\n        attn_drop_rate=0.,\n        drop_path_rate=0.3,\n        use_abs_pos_embed=False,\n        patch_norm=True,\n        out_indices=(0, 1, 2, 3),\n        with_cp=False),\n    neck=dict(in_channels=[128, 256, 512, 1024])\n)\n"
  },
  {
    "path": "configs/det/common/lsj_coco_panoptic_50e.py",
    "content": "_base_ = '../_base_/default_runtime.py'\n# dataset settings\ndataset_type = 'CocoPanopticDatasetCustom'\ndata_root = 'data/coco/'\nimg_norm_cfg = dict(\n    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)\nimage_size = (1024, 1024)\n\n# In mstrain 3x config, img_scale=[(1333, 640), (1333, 800)],\n# multiscale_mode='range'\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='LoadAnnotations', with_bbox=True, with_mask=True, with_seg=True),\n    dict(\n        type='Resize',\n        img_scale=image_size,\n        ratio_range=(0.1, 2.0),\n        multiscale_mode='range',\n        keep_ratio=True),\n    dict(\n        type='RandomCrop',\n        crop_type='absolute_range',\n        crop_size=image_size,\n        recompute_bbox=True,\n        allow_negative_crop=True),\n    dict(type='RandomFlip', flip_ratio=0.5),\n    dict(type='Normalize', **img_norm_cfg),\n    dict(type='Pad', size_divisor=32),\n    dict(type='DefaultFormatBundle'),\n    dict(\n        type='Collect',\n        keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks', 'gt_semantic_seg']),\n]\n\ntest_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='MultiScaleFlipAug',\n        img_scale=(1333, 800),\n        flip=False,\n        transforms=[\n            dict(type='Resize', keep_ratio=True),\n            dict(type='RandomFlip'),\n            dict(type='Normalize', **img_norm_cfg),\n            dict(type='Pad', size_divisor=32),\n            dict(type='ImageToTensor', keys=['img']),\n            dict(type='Collect', keys=['img']),\n        ])\n]\n\n# Use RepeatDataset to speed up training\ndata = dict(\n    samples_per_gpu=2,\n    workers_per_gpu=2,\n    train=dict(\n        type='RepeatDataset',\n        times=1,\n        dataset=dict(\n            type=dataset_type,\n            ann_file=dict(\n                ins_ann=data_root + 'annotations/panoptic_train2017_thing_only_coco.json',\n                
panoptic_ann=data_root + 'annotations/panoptic_train2017.json'),\n            img_prefix=data_root + 'train2017/',\n            seg_prefix=data_root + 'panoptic_stuff_train2017/',\n            pipeline=train_pipeline)),\n    val=dict(\n        type=dataset_type,\n        ann_file=dict(\n            ins_ann=data_root + 'annotations/instances_val2017.json',\n            panoptic_ann=data_root + 'annotations/panoptic_val2017.json'),\n        seg_prefix=data_root + 'panoptic_val2017/',\n        img_prefix=data_root + 'val2017/',\n        pipeline=test_pipeline),\n    test=dict(\n        type=dataset_type,\n        ann_file=dict(\n            ins_ann=data_root + 'annotations/instances_val2017.json',\n            panoptic_ann=data_root + 'annotations/panoptic_val2017.json'),\n        seg_prefix=data_root + 'panoptic_val2017/',\n        img_prefix=data_root + 'val2017/',\n        pipeline=test_pipeline))\n\nevaluation = dict(metric=['segm', 'panoptic'], interval=5)\n\ncheckpoint_config = dict(interval=5)\n\n# optimizer\n# this differs from the original 1x schedule, which uses SGD\noptimizer = dict(\n    type='AdamW',\n    lr=0.0001,\n    weight_decay=0.05,\n    paramwise_cfg=dict(custom_keys={'backbone': dict(lr_mult=0.25)}))\noptimizer_config = dict(grad_clip=dict(max_norm=1, norm_type=2))\n\n# learning policy: decay the LR at epochs 42 and 48 of a 50-epoch run\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[42, 48])\nrunner = dict(type='EpochBasedRunner', max_epochs=50)\n"
  },
  {
    "path": "configs/det/common/mstrain_3x_coco_instance.py",
    "content": "_base_ = '../_base_/default_runtime.py'\n# dataset settings\ndataset_type = 'CocoDataset'\ndata_root = 'data/coco/'\nimg_norm_cfg = dict(\n    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)\n\n# In mstrain 3x config, img_scale=[(1333, 640), (1333, 800)],\n# multiscale_mode='range'\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),\n    dict(\n        type='Resize',\n        img_scale=[(1333, 640), (1333, 800)],\n        multiscale_mode='range',\n        keep_ratio=True),\n    dict(type='RandomFlip', flip_ratio=0.5),\n    dict(type='Normalize', **img_norm_cfg),\n    dict(type='Pad', size_divisor=32),\n    dict(type='DefaultFormatBundle'),\n    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),\n]\ntest_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='MultiScaleFlipAug',\n        img_scale=(1333, 800),\n        flip=False,\n        transforms=[\n            dict(type='Resize', keep_ratio=True),\n            dict(type='RandomFlip'),\n            dict(type='Normalize', **img_norm_cfg),\n            dict(type='Pad', size_divisor=32),\n            dict(type='ImageToTensor', keys=['img']),\n            dict(type='Collect', keys=['img']),\n        ])\n]\n\n# Use RepeatDataset to speed up training\ndata = dict(\n    samples_per_gpu=2,\n    workers_per_gpu=2,\n    train=dict(\n        type='RepeatDataset',\n        times=3,\n        dataset=dict(\n            type=dataset_type,\n            ann_file=data_root + 'annotations/instances_train2017.json',\n            img_prefix=data_root + 'train2017/',\n            pipeline=train_pipeline)),\n    val=dict(\n        type=dataset_type,\n        ann_file=data_root + 'annotations/instances_val2017.json',\n        img_prefix=data_root + 'val2017/',\n        pipeline=test_pipeline),\n    test=dict(\n        type=dataset_type,\n        ann_file=data_root + 
'annotations/instances_val2017.json',\n        img_prefix=data_root + 'val2017/',\n        pipeline=test_pipeline))\nevaluation = dict(interval=1, metric=['segm'])\n\n# optimizer\n# this differs from the original 1x schedule, which uses SGD\noptimizer = dict(\n    type='AdamW',\n    lr=0.0001,\n    weight_decay=0.05,\n    paramwise_cfg=dict(custom_keys={'backbone': dict(lr_mult=0.25)}))\noptimizer_config = dict(grad_clip=dict(max_norm=1, norm_type=2))\n\n# learning policy\n# Experiments show that using step=[9, 11] has higher performance\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[9, 11])\nrunner = dict(type='EpochBasedRunner', max_epochs=12)\n"
  },
  {
    "path": "configs/det/common/mstrain_3x_coco_panoptic_inst_anno.py",
    "content": "_base_ = '../_base_/default_runtime.py'\n# dataset settings\ndataset_type = 'CocoPanopticDatasetCustom'\ndata_root = 'data/coco/'\nimg_norm_cfg = dict(\n    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)\n\n# In mstrain 3x config, img_scale=[(1333, 640), (1333, 800)],\n# multiscale_mode='range'\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='LoadAnnotations', with_bbox=True, with_mask=True, with_seg=True),\n    dict(\n        type='Resize',\n        img_scale=[(1333, 640), (1333, 800)],\n        multiscale_mode='range',\n        keep_ratio=True),\n    dict(type='RandomFlip', flip_ratio=0.5),\n    dict(type='Normalize', **img_norm_cfg),\n    dict(type='Pad', size_divisor=32),\n    dict(type='DefaultFormatBundle'),\n    dict(\n        type='Collect',\n        keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks', 'gt_semantic_seg']),\n]\n\ntest_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='MultiScaleFlipAug',\n        img_scale=(1333, 800),\n        flip=False,\n        transforms=[\n            dict(type='Resize', keep_ratio=True),\n            dict(type='RandomFlip'),\n            dict(type='Normalize', **img_norm_cfg),\n            dict(type='Pad', size_divisor=32),\n            dict(type='ImageToTensor', keys=['img']),\n            dict(type='Collect', keys=['img']),\n        ])\n]\n\n# Use RepeatDataset to speed up training\ndata = dict(\n    samples_per_gpu=2,\n    workers_per_gpu=2,\n    train=dict(\n        type='RepeatDataset',\n        times=3,\n        dataset=dict(\n            type=dataset_type,\n            ann_file=dict(\n                ins_ann=data_root + 'annotations/panoptic_train2017_thing_only_coco.json',\n                panoptic_ann=data_root + 'annotations/panoptic_train2017.json'),\n            img_prefix=data_root + 'train2017/',\n            seg_prefix=data_root + 'panoptic_stuff_train2017/',\n            pipeline=train_pipeline)),\n    
val=dict(\n        type=dataset_type,\n        ann_file=dict(\n            ins_ann=data_root + 'annotations/instances_val2017.json',\n            panoptic_ann=data_root + 'annotations/panoptic_val2017.json'),\n        seg_prefix=data_root + 'panoptic_val2017/',\n        img_prefix=data_root + 'val2017/',\n        pipeline=test_pipeline),\n    test=dict(\n        type=dataset_type,\n        ann_file=dict(\n            ins_ann=data_root + 'annotations/instances_val2017.json',\n            panoptic_ann=data_root + 'annotations/panoptic_val2017.json'),\n        seg_prefix=data_root + 'panoptic_val2017/',\n        img_prefix=data_root + 'val2017/',\n        pipeline=test_pipeline))\nevaluation = dict(metric=['segm', 'panoptic'])\n\n# optimizer\n# this differs from the original 1x schedule, which uses SGD\noptimizer = dict(\n    type='AdamW',\n    lr=0.0001,\n    weight_decay=0.05,\n    paramwise_cfg=dict(custom_keys={'backbone': dict(lr_mult=0.25)}))\noptimizer_config = dict(grad_clip=dict(max_norm=1, norm_type=2))\n\n# learning policy\n# Experiments show that using step=[9, 11] has higher performance\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[9, 11])\nrunner = dict(type='EpochBasedRunner', max_epochs=12)\n"
  },
  {
    "path": "configs/det/common/mstrain_3x_coco_panoptic_inst_anno_detr_aug.py",
    "content": "_base_ = '../_base_/default_runtime.py'\n# dataset settings\ndataset_type = 'CocoPanopticDatasetCustom'\ndata_root = 'data/coco/'\nimg_norm_cfg = dict(\n    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)\n\n# In mstrain 3x config, img_scale=[(1333, 640), (1333, 800)],\n# multiscale_mode='range'\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='LoadAnnotations', with_bbox=True, with_mask=True, with_seg=True),\n    dict(type='RandomFlip', flip_ratio=0.5),\n    dict(\n        type='AutoAugment',\n        policies=[[\n            dict(\n                type='Resize',\n                img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),\n                           (608, 1333), (640, 1333), (672, 1333), (704, 1333),\n                           (736, 1333), (768, 1333), (800, 1333)],\n                multiscale_mode='value',\n                keep_ratio=True)\n        ],\n          [\n              dict(\n                  type='Resize',\n                  img_scale=[(400, 1333), (500, 1333), (600, 1333)],\n                  multiscale_mode='value',\n                  keep_ratio=True),\n              dict(\n                  type='RandomCrop',\n                  crop_type='relative',\n                  crop_size=(0.7, 0.7),\n                  allow_negative_crop=True),\n              dict(\n                  type='Resize',\n                  img_scale=[(480, 1333), (512, 1333), (544, 1333),\n                             (576, 1333), (608, 1333), (640, 1333),\n                             (672, 1333), (704, 1333), (736, 1333),\n                             (768, 1333), (800, 1333)],\n                  multiscale_mode='value',\n                  override=True,\n                  keep_ratio=True)\n          ]]),\n    dict(type='Normalize', **img_norm_cfg),\n    dict(type='Pad', size_divisor=32),\n    dict(type='DefaultFormatBundle'),\n    dict(\n        type='Collect',\n        keys=['img', 
'gt_bboxes', 'gt_labels', 'gt_masks', 'gt_semantic_seg']),\n]\n\ntest_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='MultiScaleFlipAug',\n        img_scale=(1333, 800),\n        flip=False,\n        transforms=[\n            dict(type='Resize', keep_ratio=True),\n            dict(type='RandomFlip'),\n            dict(type='Normalize', **img_norm_cfg),\n            dict(type='Pad', size_divisor=32),\n            dict(type='ImageToTensor', keys=['img']),\n            dict(type='Collect', keys=['img']),\n        ])\n]\n\n# Use RepeatDataset to speed up training\ndata = dict(\n    samples_per_gpu=2,\n    workers_per_gpu=2,\n    train=dict(\n        type='RepeatDataset',\n        times=3,\n        dataset=dict(\n            type=dataset_type,\n            ann_file=dict(\n                ins_ann=data_root + 'annotations/panoptic_train2017_thing_only_coco.json',\n                panoptic_ann=data_root + 'annotations/panoptic_train2017.json'),\n            img_prefix=data_root + 'train2017/',\n            seg_prefix=data_root + 'panoptic_stuff_train2017/',\n            pipeline=train_pipeline)),\n    val=dict(\n        type=dataset_type,\n        ann_file=dict(\n            ins_ann=data_root + 'annotations/instances_val2017.json',\n            panoptic_ann=data_root + 'annotations/panoptic_val2017.json'),\n        seg_prefix=data_root + 'panoptic_val2017/',\n        img_prefix=data_root + 'val2017/',\n        pipeline=test_pipeline),\n    test=dict(\n        type=dataset_type,\n        ann_file=dict(\n            ins_ann=data_root + 'annotations/instances_val2017.json',\n            panoptic_ann=data_root + 'annotations/panoptic_val2017.json'),\n        seg_prefix=data_root + 'panoptic_val2017/',\n        img_prefix=data_root + 'val2017/',\n        pipeline=test_pipeline))\nevaluation = dict(metric=['segm', 'panoptic'])\n\n# optimizer\n# this is different from the original 1x schedule that use SGD\noptimizer = dict(\n    type='AdamW',\n    
lr=0.0001,\n    weight_decay=0.05,\n    paramwise_cfg=dict(custom_keys={'backbone': dict(lr_mult=0.25)}))\noptimizer_config = dict(grad_clip=dict(max_norm=1, norm_type=2))\n\n# learning policy\n# Experiments show that using step=[9, 11] has higher performance\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[9, 11])\nrunner = dict(type='EpochBasedRunner', max_epochs=12)\n"
  },
  {
    "path": "configs/det/common/mstrain_64e_city_panoptic.py",
    "content": "_base_ = '../_base_/default_runtime.py'\n# dataset settings\ndataset_type = 'CityscapesPanopticDataset'\ndata_root = 'data/cityscapes/'\n\nimg_norm_cfg = dict(\n    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)\n\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='LoadAnnotations', with_bbox=True, with_mask=True, with_seg=True),\n    dict(\n        type='Resize', img_scale=[(512, 1024), (2048, 4096)], multiscale_mode='range', keep_ratio=True),\n    dict(type='RandomCrop', crop_size=(1024, 2048)),\n    dict(type='RandomFlip', flip_ratio=0.5),\n    dict(type='Normalize', **img_norm_cfg),\n    dict(type='Pad', size_divisor=32),\n    dict(type='DefaultFormatBundle'),\n    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks', 'gt_semantic_seg']),\n]\n\n\ntest_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='MultiScaleFlipAug',\n        img_scale=(2048, 1024),\n        flip=False,\n        transforms=[\n            dict(type='Resize', keep_ratio=True),\n            dict(type='RandomFlip'),\n            dict(type='Normalize', **img_norm_cfg),\n            dict(type='Pad', size_divisor=32),\n            dict(type='ImageToTensor', keys=['img']),\n            dict(type='Collect', keys=['img']),\n        ])\n]\n\ndata = dict(\n    samples_per_gpu=2,\n    workers_per_gpu=2,\n    train=dict(\n        type='RepeatDataset',\n        times=8,\n        dataset=dict(\n            type=dataset_type,\n            ann_file=dict(\n                ins_ann=data_root + 'annotations/instancesonly_filtered_gtFine_train.json',\n                panoptic_ann=data_root + 'annotations/cityscapes_panoptic_train.json'\n            ),\n            img_prefix=data_root + 'leftImg8bit/train/',\n            seg_prefix=data_root + 'gtFine/train',\n            pipeline=train_pipeline)),\n    val=dict(\n        type=dataset_type,\n        ann_file=dict(\n            ins_ann=data_root 
+ 'annotations/instancesonly_filtered_gtFine_val.json',\n            panoptic_ann=data_root + \"annotations/cityscapes_panoptic_val.json\"\n        ),\n        img_prefix=data_root + 'leftImg8bit/val/',\n        seg_prefix=data_root + 'gtFine/cityscapes_panoptic_val',\n        pipeline=test_pipeline),\n    test=dict(\n        type=dataset_type,\n        ann_file=dict(\n            ins_ann=data_root + 'annotations/instancesonly_filtered_gtFine_val.json',\n            panoptic_ann=data_root + \"annotations/cityscapes_panoptic_val.json\"\n        ),\n        img_prefix=data_root + 'leftImg8bit/val/',\n        seg_prefix=data_root + 'gtFine/cityscapes_panoptic_val',\n        pipeline=test_pipeline))\n\nevaluation = dict(metric=['panoptic'])\n\n# optimizer\n# this differs from the original 1x schedule, which uses SGD\noptimizer = dict(\n    type='AdamW',\n    lr=0.0001,\n    weight_decay=0.05,\n    paramwise_cfg=dict(custom_keys={'backbone': dict(lr_mult=0.25)}))\noptimizer_config = dict(grad_clip=dict(max_norm=1, norm_type=2))\n\n# learning policy\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=500,\n    warmup_ratio=0.001,\n    # [7] yields higher performance than [6]\n    step=[7])\nrunner = dict(\n    type='EpochBasedRunner', max_epochs=8)  # effective epochs = 8 (runner) * 8 (RepeatDataset times) = 64\n"
  },
  {
    "path": "configs/det/knet_cityscapes_step/knet_s3_r50_fpn.py",
    "content": "_base_ = [\n    '../_base_/schedules/schedule_1x.py',\n    '../_base_/default_runtime.py',\n    '../_base_/models/knet_citystep_s3_r50_fpn.py',\n    '../_base_/datasets/cityscapes_step.py',\n]\n\n\nnum_proposals = 100\n# load_from = \"/mnt/lustre/lixiangtai/pretrained/video_knet_vis/knet_r50_city.pth\"\nload_from = None\n\nwork_dir = 'logger/blackhole'\n\nrunner = dict(type='EpochBasedRunner', max_epochs=8)\n\nmodel = dict(\n    type='KNet',\n    backbone=dict(\n        type='ResNet',\n        depth=50,\n        num_stages=4,\n        out_indices=(0, 1, 2, 3),\n        frozen_stages=1,\n        norm_cfg=dict(type='SyncBN', requires_grad=True),\n        norm_eval=True,\n        style='pytorch',\n        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),\n    roi_head=dict(\n        type='KernelIterHead',\n        merge_joint=True,\n    ),\n    test_cfg=dict(\n        rpn=None,\n        rcnn=dict(\n            max_per_img=num_proposals,\n            mask_thr=0.5,\n            stuff_score_thr=0.05,\n            merge_stuff_thing=dict(\n                overlap_thr=0.6,\n                iou_thr=0.5, stuff_max_area=4096, instance_score_thr=0.3)))\n)\n\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[7, ],\n)\ndata = dict(\n    samples_per_gpu=2,\n    workers_per_gpu=2,\n)\n"
  },
  {
    "path": "configs/det/knet_cityscapes_step/knet_s3_swin_b_fpn.py",
    "content": "_base_ = [\n    '../_base_/schedules/schedule_1x.py',\n    '../_base_/default_runtime.py',\n    '../_base_/models/knet_citystep_s3_r50_fpn.py',\n    '../_base_/datasets/cityscapes_step.py',\n]\n\n\nnum_proposals = 100\n# load_from = \"/mnt/lustre/lixiangtai/pretrained/video_knet_vis/knet_swin_b_city.pth\"\nload_from = None\n\nwork_dir = 'logger/blackhole'\n\nrunner = dict(type='EpochBasedRunner', max_epochs=8)\n\nmodel = dict(\n    type='KNet',\n    backbone=dict(\n        _delete_=True,\n        type='SwinTransformerDIY',\n        embed_dims=128,\n        depths=[2, 2, 18, 2],\n        num_heads=[4, 8, 16, 32],\n        window_size=7,\n        mlp_ratio=4.,\n        qkv_bias=True,\n        qk_scale=None,\n        drop_rate=0.,\n        attn_drop_rate=0.,\n        drop_path_rate=0.3,\n        use_abs_pos_embed=False,\n        patch_norm=True,\n        out_indices=(0, 1, 2, 3),\n        with_cp=False),\n    neck=dict(in_channels=[128, 256, 512, 1024]),\n    roi_head=dict(\n        type='KernelIterHead',\n        merge_joint=True,\n    ),\n    test_cfg=dict(\n        rpn=None,\n        rcnn=dict(\n            max_per_img=num_proposals,\n            mask_thr=0.5,\n            stuff_score_thr=0.05,\n            merge_stuff_thing=dict(\n                overlap_thr=0.6,\n                iou_thr=0.5, stuff_max_area=4096, instance_score_thr=0.3)))\n)\n\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[7, ],\n)\ndata = dict(\n    samples_per_gpu=1,\n    workers_per_gpu=2,\n)\n"
  },
  {
    "path": "configs/det/knet_cityscapes_step/knet_s3_swin_l_fpn.py",
    "content": "_base_ = [\n    '../_base_/schedules/schedule_1x.py',\n    '../_base_/default_runtime.py',\n    '../_base_/models/knet_citystep_s3_r50_fpn.py',\n    '../_base_/datasets/cityscapes_step.py',\n]\n\n\nnum_proposals = 100\n# load_from = \"/mnt/lustre/lixiangtai/pretrained/video_knet_vis/knet_swin_l_city.pth\"\nload_from = None\n\nwork_dir = 'logger/blackhole'\n\nrunner = dict(type='EpochBasedRunner', max_epochs=8)\n\nmodel = dict(\n    type='KNet',\n    backbone=dict(\n        _delete_=True,\n        type='SwinTransformerDIY',\n        embed_dims=192,\n        depths=[2, 2, 18, 2],\n        num_heads=[6, 12, 24, 48],\n        window_size=7,\n        mlp_ratio=4.,\n        qkv_bias=True,\n        qk_scale=None,\n        drop_rate=0.,\n        attn_drop_rate=0.,\n        drop_path_rate=0.2,\n        use_abs_pos_embed=False,\n        patch_norm=True,\n        out_indices=(0, 1, 2, 3),\n        with_cp=False),\n    neck=dict(in_channels=[192, 384, 768, 1536]),\n    roi_head=dict(\n        type='KernelIterHead',\n        merge_joint=True,\n    ),\n    test_cfg=dict(\n        rpn=None,\n        rcnn=dict(\n            max_per_img=num_proposals,\n            mask_thr=0.5,\n            stuff_score_thr=0.05,\n            merge_stuff_thing=dict(\n                overlap_thr=0.6,\n                iou_thr=0.5, stuff_max_area=4096, instance_score_thr=0.3)))\n)\n\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[7, ],\n)\ndata = dict(\n    samples_per_gpu=1,\n    workers_per_gpu=2,\n)\n"
  },
  {
    "path": "configs/det/video_knet_kitti_step/video_knet_s3_r50_rpn_1x_kitti_step_sigmoid_stride2_mask_embed_link_ffn_joint_train.py",
    "content": "_base_ = [\n    '../_base_/schedules/schedule_1x.py',\n    '../_base_/default_runtime.py',\n    '../_base_/models/knet_kitti_step_s3_r50_fpn.py',\n    '../_base_/datasets/kitti_step_vps.py',\n]\n\nload_from = None\n\nnum_stages = 3\nconv_kernel_size = 1\nnum_thing_classes = 2\nnum_stuff_classes = 17\nnum_classes = num_thing_classes + num_stuff_classes\n\n\nmodel = dict(\n    type=\"VideoKNetQuansiEmbedFCJointTrain\",\n    cityscapes=False,\n    kitti_step=True,\n    link_previous=True,\n    mask_assign_stride=2,\n    num_thing_classes=2,\n    num_stuff_classes=17,\n    ignore_label=255,\n    backbone=dict(\n            type='ResNet',\n            depth=50,\n            num_stages=4,\n            out_indices=(0, 1, 2, 3),\n            frozen_stages=1,\n            norm_cfg=dict(type='SyncBN', requires_grad=True),\n    ),\n    rpn_head=dict(\n        loss_seg=dict(\n                _delete_=True,\n                type='CrossEntropyLoss',\n                use_sigmoid=False,\n                loss_weight=1.0),\n        feat_downsample_stride=4,\n    ),\n    # add video_knet_vis roi head\n    track_head=dict(\n        type='QuasiDenseMaskEmbedHeadGTMask',\n        num_convs=0,\n        num_fcs=2,\n        roi_feat_size=1,\n        in_channels=256,\n        fc_out_channels=256,\n        embed_channels=256,\n        norm_cfg=dict(type='GN', num_groups=32),\n        loss_track=dict(type='MultiPosCrossEntropyLoss', loss_weight=0.25),\n        loss_track_aux=dict(\n            type='L2Loss',\n            neg_pos_ub=3,\n            pos_margin=0,\n            neg_margin=0.1,\n            hard_mining=True,\n            loss_weight=1.0),\n    ),\n    # add tracker config\n    tracker=dict(\n        type='QuasiDenseEmbedTracker',\n        init_score_thr=0.35,\n        obj_score_thr=0.3,\n        match_score_thr=0.5,\n        memo_tracklet_frames=5,\n        memo_backdrop_frames=1,\n        memo_momentum=0.8,\n        nms_conf_thr=0.5,\n        
nms_backdrop_iou_thr=0.3,\n        nms_class_iou_thr=0.7,\n        with_cats=True,\n        match_metric='bisoftmax'\n    ),\n    # roi head\n    roi_head=dict(\n        type='VideoKernelIterHead',\n        num_stages=num_stages,\n        num_thing_classes=num_thing_classes,\n        num_stuff_classes=num_stuff_classes,\n        with_track=True,\n        merge_joint=True,\n        mask_head=[\n            dict(\n                type='VideoKernelUpdateHead',\n                num_classes=19,\n                previous='placeholder',\n                previous_type=\"ffn\",\n                num_thing_classes=2,\n                num_stuff_classes=17,\n                num_ffn_fcs=2,\n                num_heads=8,\n                num_cls_fcs=1,\n                num_mask_fcs=1,\n                feedforward_channels=2048,\n                in_channels=256,\n                out_channels=256,\n                dropout=0.0,\n                mask_thr=0.5,\n                conv_kernel_size=conv_kernel_size,\n                mask_upsample_stride=4,\n                ffn_act_cfg=dict(type='ReLU', inplace=True),\n                with_ffn=True,\n                feat_transform_cfg=dict(\n                    conv_cfg=dict(type='Conv2d'), act_cfg=None),\n                kernel_updator_cfg=dict(\n                    type='KernelUpdator',\n                    in_channels=256,\n                    feat_channels=256,\n                    out_channels=256,\n                    input_feat_shape=3,\n                    act_cfg=dict(type='ReLU', inplace=True),\n                    norm_cfg=dict(type='LN')),\n                loss_mask=dict(\n                    type='CrossEntropyLoss',\n                    use_sigmoid=True,\n                    loss_weight=1.0),\n                loss_dice=dict(\n                    type='DiceLoss', loss_weight=4.0),\n                loss_cls=dict(\n                    type='FocalLoss',\n                    use_sigmoid=True,\n                    gamma=2.0,\n         
           alpha=0.25,\n                    loss_weight=2.0),\n            ) for _ in range(num_stages)\n        ]\n    ),\n    track_train_cfg=dict(\n        assigner=dict(\n            type='MaskHungarianAssigner',\n            cls_cost=dict(type='FocalLossCost', weight=2.0),\n            dice_cost=dict(type='DiceCost', weight=4.0, pred_act=True),\n            mask_cost=dict(type='MaskCost', weight=1.0, pred_act=True)),\n        sampler=dict(type='MaskPseudoSampler'),),\n    bbox_roi_extractor=None\n)\n\nimg_norm_cfg = dict(\n    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=False)\n\ntrain_pipeline = [\n    dict(type='LoadMultiImagesDirect'),\n    dict(type='LoadMultiAnnotationsDirect', with_depth=False, divisor=-1, cherry_pick=True, cherry=[11, 13]),\n    dict(type='SeqResizeWithDepth', img_scale=(384, 1248), ratio_range=[0.5, 2.0], keep_ratio=True),\n    dict(type='SeqFlipWithDepth', flip_ratio=0.5),\n    dict(type='SeqRandomCropWithDepth', crop_size=(384, 1248), share_params=True),\n    dict(type='SeqNormalizeWithDepth', **img_norm_cfg),\n    dict(type='SeqPadWithDepth', size_divisor=32),\n    dict(\n        type='VideoCollect',\n        keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks', 'gt_semantic_seg', 'gt_instance_ids',]),\n    dict(type='ConcatVideoReferences'),\n    dict(type='SeqDefaultFormatBundle', ref_prefix='ref'),\n]\n\ntest_pipeline = [\n    dict(type='LoadImgDirect'),\n    dict(\n        type='MultiScaleFlipAug',\n        scale_factor=[1.0],\n        flip=False,\n        transforms=[\n            dict(type='RandomFlip'),\n            dict(type='Normalize', **img_norm_cfg),\n            dict(type='Pad', size_divisor=32),\n            dict(type='ImageToTensor', keys=['img']),\n            dict(type='Collect',\n                 keys=['img', 'img_id', 'seq_id'],\n                 meta_keys=[\n                     'ori_shape', 'img_shape', 'pad_shape', 'scale_factor', 'flip',\n                     'flip_direction', 
'img_norm_cfg', 'ori_filename'\n                 ]),\n        ])\n]\n\nrunner = dict(type='EpochBasedRunner', max_epochs=12)\n\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[9, 11])\n\ndata = dict(\n    samples_per_gpu=1,\n    workers_per_gpu=2,\n    train=dict(\n        type='RepeatDataset',\n        times=2,\n        dataset=dict(\n            split='train',\n            ref_seq_index=[-2, -1, 1, 2],\n            test_mode=False,\n            pipeline=train_pipeline\n        )),\n    test=dict(\n        ref_seq_index=None,\n        test_mode=True,\n        pipeline=test_pipeline,\n        split='val',\n    )\n)\n\nfind_unused_parameters=True"
  },
  {
    "path": "configs/det/video_knet_kitti_step/video_knet_s3_r50_rpn_1x_kitti_step_sigmoid_stride2_mask_embed_link_ffn_joint_train_8e.py",
    "content": "_base_ = [\n    '../_base_/schedules/schedule_1x.py',\n    '../_base_/default_runtime.py',\n    '../_base_/models/knet_kitti_step_s3_r50_fpn.py',\n    '../_base_/datasets/kitti_step_vps.py',\n]\n\nload_from = None\n\nnum_stages = 3\nconv_kernel_size = 1\nnum_thing_classes = 2\nnum_stuff_classes = 17\nnum_classes = num_thing_classes + num_stuff_classes\n\n\nmodel = dict(\n    type=\"VideoKNetQuansiEmbedFCJointTrain\",\n    cityscapes=False,\n    kitti_step=True,\n    link_previous=True,\n    mask_assign_stride=2,\n    num_thing_classes=2,\n    num_stuff_classes=17,\n    ignore_label=255,\n    backbone=dict(\n            type='ResNet',\n            depth=50,\n            num_stages=4,\n            out_indices=(0, 1, 2, 3),\n            frozen_stages=1,\n            norm_cfg=dict(type='SyncBN', requires_grad=True),\n    ),\n    rpn_head=dict(\n        loss_seg=dict(\n                _delete_=True,\n                type='CrossEntropyLoss',\n                use_sigmoid=False,\n                loss_weight=1.0),\n        feat_downsample_stride=4,\n    ),\n    # add track roi head\n    track_head=dict(\n        type='QuasiDenseMaskEmbedHeadGTMask',\n        num_convs=0,\n        num_fcs=2,\n        roi_feat_size=1,\n        in_channels=256,\n        fc_out_channels=256,\n        embed_channels=256,\n        norm_cfg=dict(type='GN', num_groups=32),\n        loss_track=dict(type='MultiPosCrossEntropyLoss', loss_weight=0.25),\n        loss_track_aux=dict(\n            type='L2Loss',\n            neg_pos_ub=3,\n            pos_margin=0,\n            neg_margin=0.1,\n            hard_mining=True,\n            loss_weight=1.0),\n    ),\n    # add tracker config\n    tracker=dict(\n        type='QuasiDenseEmbedTracker',\n        init_score_thr=0.35,\n        obj_score_thr=0.3,\n        match_score_thr=0.5,\n        memo_tracklet_frames=5,\n        memo_backdrop_frames=1,\n        memo_momentum=0.8,\n        nms_conf_thr=0.5,\n        
nms_backdrop_iou_thr=0.3,\n        nms_class_iou_thr=0.7,\n        with_cats=True,\n        match_metric='bisoftmax'\n    ),\n    # roi head\n    roi_head=dict(\n        type='VideoKernelIterHead',\n        num_stages=num_stages,\n        num_thing_classes=num_thing_classes,\n        num_stuff_classes=num_stuff_classes,\n        with_track=True,\n        merge_joint=True,\n        mask_head=[\n            dict(\n                type='VideoKernelUpdateHead',\n                num_classes=19,\n                previous='placeholder',\n                previous_type=\"ffn\",\n                num_thing_classes=2,\n                num_stuff_classes=17,\n                num_ffn_fcs=2,\n                num_heads=8,\n                num_cls_fcs=1,\n                num_mask_fcs=1,\n                feedforward_channels=2048,\n                in_channels=256,\n                out_channels=256,\n                dropout=0.0,\n                mask_thr=0.5,\n                conv_kernel_size=conv_kernel_size,\n                mask_upsample_stride=4,\n                ffn_act_cfg=dict(type='ReLU', inplace=True),\n                with_ffn=True,\n                feat_transform_cfg=dict(\n                    conv_cfg=dict(type='Conv2d'), act_cfg=None),\n                kernel_updator_cfg=dict(\n                    type='KernelUpdator',\n                    in_channels=256,\n                    feat_channels=256,\n                    out_channels=256,\n                    input_feat_shape=3,\n                    act_cfg=dict(type='ReLU', inplace=True),\n                    norm_cfg=dict(type='LN')),\n                loss_mask=dict(\n                    type='CrossEntropyLoss',\n                    use_sigmoid=True,\n                    loss_weight=1.0),\n                loss_dice=dict(\n                    type='DiceLoss', loss_weight=4.0),\n                loss_cls=dict(\n                    type='FocalLoss',\n                    use_sigmoid=True,\n                    gamma=2.0,\n         
           alpha=0.25,\n                    loss_weight=2.0),\n            ) for _ in range(num_stages)\n        ]\n    ),\n    track_train_cfg=dict(\n        assigner=dict(\n            type='MaskHungarianAssigner',\n            cls_cost=dict(type='FocalLossCost', weight=2.0),\n            dice_cost=dict(type='DiceCost', weight=4.0, pred_act=True),\n            mask_cost=dict(type='MaskCost', weight=1.0, pred_act=True)),\n        sampler=dict(type='MaskPseudoSampler'),),\n    bbox_roi_extractor=None\n)\n\nimg_norm_cfg = dict(\n    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=False)\n\ntrain_pipeline = [\n    dict(type='LoadMultiImagesDirect'),\n    dict(type='LoadMultiAnnotationsDirect', with_depth=False, divisor=-1, cherry_pick=True, cherry=[11, 13]),\n    dict(type='SeqResizeWithDepth', img_scale=(384, 1248), ratio_range=[0.5, 2.0], keep_ratio=True),\n    dict(type='SeqFlipWithDepth', flip_ratio=0.5),\n    dict(type='SeqRandomCropWithDepth', crop_size=(384, 1248), share_params=True),\n    dict(type='SeqNormalizeWithDepth', **img_norm_cfg),\n    dict(type='SeqPadWithDepth', size_divisor=32),\n    dict(\n        type='VideoCollect',\n        keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks', 'gt_semantic_seg', 'gt_instance_ids',]),\n    dict(type='ConcatVideoReferences'),\n    dict(type='SeqDefaultFormatBundle', ref_prefix='ref'),\n]\n\ntest_pipeline = [\n    dict(type='LoadImgDirect'),\n    dict(\n        type='MultiScaleFlipAug',\n        scale_factor=[1.0],\n        flip=False,\n        transforms=[\n            dict(type='RandomFlip'),\n            dict(type='Normalize', **img_norm_cfg),\n            dict(type='Pad', size_divisor=32),\n            dict(type='ImageToTensor', keys=['img']),\n            dict(type='Collect',\n                 keys=['img', 'img_id', 'seq_id'],\n                 meta_keys=[\n                     'ori_shape', 'img_shape', 'pad_shape', 'scale_factor', 'flip',\n                     'flip_direction', 
'img_norm_cfg', 'ori_filename'\n                 ]),\n        ])\n]\n\nrunner = dict(type='EpochBasedRunner', max_epochs=8)\n\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[7])\n\ndata = dict(\n    samples_per_gpu=1,\n    workers_per_gpu=2,\n    train=dict(\n        type='RepeatDataset',\n        times=2,\n        dataset=dict(\n            split='train',\n            ref_seq_index=[-2, -1, 1, 2],\n            test_mode=False,\n            pipeline=train_pipeline\n        )),\n    test=dict(\n        ref_seq_index=None,\n        test_mode=True,\n        pipeline=test_pipeline,\n        split='val',\n    )\n)\n\nfind_unused_parameters=True"
  },
  {
    "path": "configs/det/video_knet_kitti_step/video_knet_s3_swinb_rpn_1x_kitti_step_sigmoid_stride2_mask_embed_link_ffn_joint_update.py",
    "content": "_base_ = [\n    '../_base_/schedules/schedule_1x.py',\n    '../_base_/default_runtime.py',\n    '../_base_/models/knet_kitti_step_s3_r50_fpn.py',\n    '../_base_/datasets/kitti_step_vps.py',\n]\n\nload_from = None\n\nnum_stages = 3\nconv_kernel_size = 1\nnum_thing_classes = 2\nnum_stuff_classes = 17\nnum_classes = num_thing_classes + num_stuff_classes\n\nmodel = dict(\n    type=\"VideoKNetQuansiEmbedFCJointTrain\",\n    cityscapes=False,\n    kitti_step=True,\n    link_previous=True,\n    mask_assign_stride=2,\n    num_thing_classes=num_thing_classes,\n    num_stuff_classes=num_stuff_classes,\n    ignore_label=255,\n    backbone=dict(\n        _delete_=True,\n        type='SwinTransformerDIY',\n        embed_dims=128,\n        depths=[2, 2, 18, 2],\n        num_heads=[4, 8, 16, 32],\n        window_size=7,\n        mlp_ratio=4,\n        qkv_bias=True,\n        qk_scale=None,\n        drop_rate=0.,\n        attn_drop_rate=0.,\n        drop_path_rate=0.3,\n        use_abs_pos_embed=False,\n        patch_norm=True,\n        out_indices=(0, 1, 2, 3),\n        with_cp=False\n    ),\n    neck=dict(in_channels=[128, 256, 512, 1024]),\n    rpn_head=dict(\n        loss_seg=dict(\n                _delete_=True,\n                type='CrossEntropyLoss',\n                use_sigmoid=False,\n                loss_weight=1.0),\n        feat_downsample_stride=4,\n    ),\n    # add track roi head\n    track_head=dict(\n        type='QuasiDenseMaskEmbedHeadGTMask',\n        num_convs=0,\n        num_fcs=2,\n        roi_feat_size=1,\n        in_channels=256,\n        fc_out_channels=256,\n        embed_channels=256,\n        norm_cfg=dict(type='GN', num_groups=32),\n        loss_track=dict(type='MultiPosCrossEntropyLoss', loss_weight=0.25),\n        loss_track_aux=dict(\n            type='L2Loss',\n            neg_pos_ub=3,\n            pos_margin=0,\n            neg_margin=0.1,\n            hard_mining=True,\n            loss_weight=1.0),\n    ),\n    # add 
tracker config\n    tracker=dict(\n        type='QuasiDenseEmbedTracker',\n        init_score_thr=0.35,\n        obj_score_thr=0.3,\n        match_score_thr=0.5,\n        memo_tracklet_frames=5,\n        memo_backdrop_frames=1,\n        memo_momentum=0.8,\n        nms_conf_thr=0.5,\n        nms_backdrop_iou_thr=0.3,\n        nms_class_iou_thr=0.7,\n        with_cats=True,\n        match_metric='bisoftmax'\n    ),\n    # roi head\n    roi_head=dict(\n        type='VideoKernelIterHead',\n        num_stages=num_stages,\n        num_thing_classes=2,\n        num_stuff_classes=17,\n        with_track=True,\n        merge_joint=True,\n        mask_head=[\n            dict(\n                type='VideoKernelUpdateHead',\n                num_classes=num_classes,\n                previous='placeholder',\n                previous_link=\"update_dynamic_cov\",\n                previous_type=\"update\",\n                num_thing_classes=num_thing_classes,\n                num_stuff_classes=num_stuff_classes,\n                num_ffn_fcs=2,\n                num_heads=8,\n                num_cls_fcs=1,\n                num_mask_fcs=1,\n                feedforward_channels=2048,\n                in_channels=256,\n                out_channels=256,\n                dropout=0.0,\n                mask_thr=0.5,\n                conv_kernel_size=conv_kernel_size,\n                mask_upsample_stride=4,\n                ffn_act_cfg=dict(type='ReLU', inplace=True),\n                with_ffn=True,\n                feat_transform_cfg=dict(\n                    conv_cfg=dict(type='Conv2d'), act_cfg=None),\n                kernel_updator_cfg=dict(\n                    type='KernelUpdator',\n                    in_channels=256,\n                    feat_channels=256,\n                    out_channels=256,\n                    input_feat_shape=3,\n                    act_cfg=dict(type='ReLU', inplace=True),\n                    norm_cfg=dict(type='LN')),\n                loss_mask=dict(\n     
               type='CrossEntropyLoss',\n                    use_sigmoid=True,\n                    loss_weight=1.0),\n                loss_dice=dict(\n                    type='DiceLoss', loss_weight=4.0),\n                loss_cls=dict(\n                    type='FocalLoss',\n                    use_sigmoid=True,\n                    gamma=2.0,\n                    alpha=0.25,\n                    loss_weight=2.0),\n            ) for _ in range(num_stages)\n        ]\n    ),\n    track_train_cfg=dict(\n        assigner=dict(\n            type='MaskHungarianAssigner',\n            cls_cost=dict(type='FocalLossCost', weight=2.0),\n            dice_cost=dict(type='DiceCost', weight=4.0, pred_act=True),\n            mask_cost=dict(type='MaskCost', weight=1.0, pred_act=True)),\n        sampler=dict(type='MaskPseudoSampler'),),\n    bbox_roi_extractor=None\n)\n\n\n\nimg_norm_cfg = dict(\n    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=False)\n\ntrain_pipeline = [\n    dict(type='LoadMultiImagesDirect'),\n    dict(type='LoadMultiAnnotationsDirect', with_depth=False, divisor=-1, cherry_pick=True, cherry=[11, 13]),\n    dict(type='SeqResizeWithDepth', img_scale=(384, 1248), ratio_range=[0.5, 2.0], keep_ratio=True),\n    dict(type='SeqFlipWithDepth', flip_ratio=0.5),\n    dict(type='SeqRandomCropWithDepth', crop_size=(384, 1248), share_params=True),\n    dict(type='SeqNormalizeWithDepth', **img_norm_cfg),\n    dict(type='SeqPadWithDepth', size_divisor=32),\n    dict(\n        type='VideoCollect',\n        keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks', 'gt_semantic_seg', 'gt_instance_ids',]),\n    dict(type='ConcatVideoReferences'),\n    dict(type='SeqDefaultFormatBundle', ref_prefix='ref'),\n]\n\ntest_pipeline = [\n    dict(type='LoadImgDirect'),\n    dict(\n        type='MultiScaleFlipAug',\n        scale_factor=[1.0],\n        flip=False,\n        transforms=[\n            dict(type='RandomFlip'),\n            dict(type='Normalize', 
**img_norm_cfg),\n            dict(type='Pad', size_divisor=32),\n            dict(type='ImageToTensor', keys=['img']),\n            dict(type='Collect',\n                 keys=['img', 'img_id', 'seq_id'],\n                 meta_keys=[\n                     'ori_shape', 'img_shape', 'pad_shape', 'scale_factor', 'flip',\n                     'flip_direction', 'img_norm_cfg', 'ori_filename'\n                 ]),\n        ])\n]\n\nrunner = dict(type='EpochBasedRunner', max_epochs=12)\n\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[9, 11])\n\ndata = dict(\n    samples_per_gpu=1,\n    workers_per_gpu=2,\n    train=dict(\n        type='RepeatDataset',\n        times=2,\n        dataset=dict(\n            split='train',\n            ref_seq_index=[-2, -1, 1, 2],\n            test_mode=False,\n            pipeline=train_pipeline\n        )),\n    test=dict(\n        ref_seq_index=None,\n        test_mode=True,\n        pipeline=test_pipeline,\n        split='val',\n    )\n)\n\nfind_unused_parameters=True"
  },
  {
    "path": "configs/det/video_knet_kitti_step/video_knet_s3_swinl_rpn_1x_kitti_step_sigmoid_stride2_mask_embed_link_ffn_joint_update.py",
    "content": "_base_ = [\n    '../_base_/schedules/schedule_1x.py',\n    '../_base_/default_runtime.py',\n    '../_base_/models/knet_kitti_step_s3_r50_fpn.py',\n    '../_base_/datasets/kitti_step_vps.py',\n]\n\nload_from = None\n\nnum_stages = 3\nconv_kernel_size = 1\nnum_thing_classes = 2\nnum_stuff_classes = 17\nnum_classes = num_thing_classes + num_stuff_classes\n\n\nmodel = dict(\n    type=\"VideoKNetQuansiEmbedFCJointTrain\",\n    cityscapes=False,\n    kitti_step=True,\n    link_previous=True,\n    mask_assign_stride=2,\n    num_thing_classes=2,\n    num_stuff_classes=17,\n    ignore_label=255,\n    backbone=dict(\n        _delete_=True,\n        type='SwinTransformerDIY',\n        embed_dims=192,\n        depths=[2, 2, 18, 2],\n        num_heads=[6, 12, 24, 48],\n        window_size=7,\n        mlp_ratio=4.,\n        qkv_bias=True,\n        qk_scale=None,\n        drop_rate=0.,\n        attn_drop_rate=0.,\n        drop_path_rate=0.2,\n        use_abs_pos_embed=False,\n        patch_norm=True,\n        out_indices=(0, 1, 2, 3),\n        with_cp=False),\n    neck=dict(in_channels=[192, 384, 768, 1536]),\n    rpn_head=dict(\n        loss_seg=dict(\n                _delete_=True,\n                type='CrossEntropyLoss',\n                use_sigmoid=False,\n                loss_weight=1.0),\n        feat_downsample_stride=4,\n    ),\n    # add track roi head\n    track_head=dict(\n        type='QuasiDenseMaskEmbedHeadGTMask',\n        num_convs=0,\n        num_fcs=2,\n        roi_feat_size=1,\n        in_channels=256,\n        fc_out_channels=256,\n        embed_channels=256,\n        norm_cfg=dict(type='GN', num_groups=32),\n        loss_track=dict(type='MultiPosCrossEntropyLoss', loss_weight=0.25),\n        loss_track_aux=dict(\n            type='L2Loss',\n            neg_pos_ub=3,\n            pos_margin=0,\n            neg_margin=0.1,\n            hard_mining=True,\n            loss_weight=1.0),\n    ),\n    # add tracker config\n    
tracker=dict(\n        type='QuasiDenseEmbedTracker',\n        init_score_thr=0.35,\n        obj_score_thr=0.3,\n        match_score_thr=0.5,\n        memo_tracklet_frames=5,\n        memo_backdrop_frames=1,\n        memo_momentum=0.8,\n        nms_conf_thr=0.5,\n        nms_backdrop_iou_thr=0.3,\n        nms_class_iou_thr=0.7,\n        with_cats=True,\n        match_metric='bisoftmax'\n    ),\n    # roi head\n    roi_head=dict(\n        type='VideoKernelIterHead',\n        num_stages=num_stages,\n        num_thing_classes=2,\n        num_stuff_classes=17,\n        with_track=True,\n        merge_joint=True,\n        mask_head=[\n            dict(\n                type='VideoKernelUpdateHead',\n                num_classes=19,\n                previous='placeholder',\n                previous_link=\"update_dynamic_cov\",\n                previous_type=\"update\",\n                num_thing_classes=2,\n                num_stuff_classes=17,\n                num_ffn_fcs=2,\n                num_heads=8,\n                num_cls_fcs=1,\n                num_mask_fcs=1,\n                feedforward_channels=2048,\n                in_channels=256,\n                out_channels=256,\n                dropout=0.0,\n                mask_thr=0.5,\n                conv_kernel_size=conv_kernel_size,\n                mask_upsample_stride=4,\n                ffn_act_cfg=dict(type='ReLU', inplace=True),\n                with_ffn=True,\n                feat_transform_cfg=dict(\n                    conv_cfg=dict(type='Conv2d'), act_cfg=None),\n                kernel_updator_cfg=dict(\n                    type='KernelUpdator',\n                    in_channels=256,\n                    feat_channels=256,\n                    out_channels=256,\n                    input_feat_shape=3,\n                    act_cfg=dict(type='ReLU', inplace=True),\n                    norm_cfg=dict(type='LN')),\n                loss_mask=dict(\n                    type='CrossEntropyLoss',\n                   
 use_sigmoid=True,\n                    loss_weight=1.0),\n                loss_dice=dict(\n                    type='DiceLoss', loss_weight=4.0),\n                loss_cls=dict(\n                    type='FocalLoss',\n                    use_sigmoid=True,\n                    gamma=2.0,\n                    alpha=0.25,\n                    loss_weight=2.0),\n            ) for _ in range(num_stages)\n        ]\n    ),\n    track_train_cfg=dict(\n        assigner=dict(\n            type='MaskHungarianAssigner',\n            cls_cost=dict(type='FocalLossCost', weight=2.0),\n            dice_cost=dict(type='DiceCost', weight=4.0, pred_act=True),\n            mask_cost=dict(type='MaskCost', weight=1.0, pred_act=True)),\n        sampler=dict(type='MaskPseudoSampler'),),\n    bbox_roi_extractor=None\n)\n\nwork_dir = 'logger/ks_wodepth_4x8_step_stride2_nocrop_2_17'\n\nimg_norm_cfg = dict(\n    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=False)\n\ntrain_pipeline = [\n    dict(type='LoadMultiImagesDirect'),\n    dict(type='LoadMultiAnnotationsDirect', with_depth=False, divisor=-1, cherry_pick=True, cherry=[11, 13]),\n    dict(type='SeqResizeWithDepth', img_scale=(384, 1248), ratio_range=[0.5, 2.0], keep_ratio=True),\n    dict(type='SeqFlipWithDepth', flip_ratio=0.5),\n    dict(type='SeqRandomCropWithDepth', crop_size=(384, 1248), share_params=True),\n    dict(type='SeqNormalizeWithDepth', **img_norm_cfg),\n    dict(type='SeqPadWithDepth', size_divisor=32),\n    dict(\n        type='VideoCollect',\n        keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks', 'gt_semantic_seg', 'gt_instance_ids',]),\n    dict(type='ConcatVideoReferences'),\n    dict(type='SeqDefaultFormatBundle', ref_prefix='ref'),\n]\n\ntest_pipeline = [\n    dict(type='LoadImgDirect'),\n    dict(\n        type='MultiScaleFlipAug',\n        scale_factor=[1.0],\n        flip=False,\n        transforms=[\n            dict(type='RandomFlip'),\n            dict(type='Normalize', 
**img_norm_cfg),\n            dict(type='Pad', size_divisor=32),\n            dict(type='ImageToTensor', keys=['img']),\n            dict(type='Collect',\n                 keys=['img', 'img_id', 'seq_id'],\n                 meta_keys=[\n                     'ori_shape', 'img_shape', 'pad_shape', 'scale_factor', 'flip',\n                     'flip_direction', 'img_norm_cfg', 'ori_filename'\n                 ]),\n        ])\n]\n\nrunner = dict(type='EpochBasedRunner', max_epochs=12)\n\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[9, 11])\n\ndata = dict(\n    samples_per_gpu=1,\n    workers_per_gpu=2,\n    train=dict(\n        type='RepeatDataset',\n        times=2,\n        dataset=dict(\n            split='train',\n            ref_seq_index=[-2, -1, 1, 2],\n            test_mode=False,\n            pipeline=train_pipeline\n        )),\n    test=dict(\n        ref_seq_index=None,\n        test_mode=True,\n        pipeline=test_pipeline,\n        split='val',\n    )\n)\n\nfind_unused_parameters=True"
  },
  {
    "path": "configs/det/video_knet_kitti_step/video_knet_s3_swinl_rpn_1x_kitti_step_sigmoid_stride2_mask_embed_link_ffn_update_conv_short_track_fc.py",
    "content": "_base_ = [\n    '../_base_/schedules/schedule_1x.py',\n    '../_base_/default_runtime.py',\n    '../_base_/models/knet_kitti_step_s3_r50_fpn.py',\n    '../_base_/datasets/kitti_step_vps.py',\n]\n# load_from = \"/mnt/lustre/lixiangtai/project/Knet/work_dirs/city_step/swin_l_joint_8e/latest.pth\"\n\nload_from = None\n\nnum_stages = 3\nconv_kernel_size = 1\n\nmodel = dict(\n    type=\"VideoKNetQuansiEmbedFCJointTrain\",\n    cityscapes=False,\n    kitti_step=True,\n    link_previous=True,\n    mask_assign_stride=2,\n    num_thing_classes=2,\n    num_stuff_classes=17,\n    ignore_label=255,\n    backbone=dict(\n        _delete_=True,\n        type='SwinTransformerDIY',\n        embed_dims=192,\n        depths=[2, 2, 18, 2],\n        num_heads=[6, 12, 24, 48],\n        window_size=7,\n        mlp_ratio=4.,\n        qkv_bias=True,\n        qk_scale=None,\n        drop_rate=0.,\n        attn_drop_rate=0.,\n        drop_path_rate=0.2,\n        use_abs_pos_embed=False,\n        patch_norm=True,\n        out_indices=(0, 1, 2, 3),\n        with_cp=False),\n    neck=dict(in_channels=[192, 384, 768, 1536]),\n    rpn_head=dict(\n        loss_seg=dict(\n                _delete_=True,\n                type='CrossEntropyLoss',\n                use_sigmoid=False,\n                loss_weight=1.0),\n        feat_downsample_stride=4,\n    ),\n    # add track roi head\n    track_head=dict(\n        type='QuasiDenseMaskEmbedHeadGTMask',\n        num_convs=0,\n        num_fcs=1,\n        roi_feat_size=1,\n        in_channels=256,\n        fc_out_channels=256,\n        embed_channels=256,\n        norm_cfg=dict(type='GN', num_groups=32),\n        loss_track=dict(type='MultiPosCrossEntropyLoss', loss_weight=0.25),\n        loss_track_aux=dict(\n            type='L2Loss',\n            neg_pos_ub=3,\n            pos_margin=0,\n            neg_margin=0.1,\n            hard_mining=True,\n            loss_weight=1.0),\n    ),\n    # add tracker config\n    tracker=dict(\n        
type='QuasiDenseEmbedTracker',\n        init_score_thr=0.35,\n        obj_score_thr=0.3,\n        match_score_thr=0.5,\n        memo_tracklet_frames=5,\n        memo_backdrop_frames=1,\n        memo_momentum=0.8,\n        nms_conf_thr=0.5,\n        nms_backdrop_iou_thr=0.3,\n        nms_class_iou_thr=0.7,\n        with_cats=True,\n        match_metric='bisoftmax'\n    ),\n    # roi head\n    roi_head=dict(\n        type='VideoKernelIterHead',\n        num_stages=num_stages,\n        num_thing_classes=2,\n        num_stuff_classes=17,\n        with_track=True,\n        merge_joint=True,\n        mask_head=[\n            dict(\n                type='VideoKernelUpdateHead',\n                num_classes=19,\n                previous='placeholder',\n                previous_link=\"update_dynamic_cov\",\n                previous_type=\"ffn\",\n                num_thing_classes=2,\n                num_stuff_classes=17,\n                num_ffn_fcs=2,\n                num_heads=8,\n                num_cls_fcs=1,\n                num_mask_fcs=1,\n                feedforward_channels=2048,\n                in_channels=256,\n                out_channels=256,\n                dropout=0.0,\n                mask_thr=0.5,\n                conv_kernel_size=conv_kernel_size,\n                mask_upsample_stride=4,\n                ffn_act_cfg=dict(type='ReLU', inplace=True),\n                with_ffn=True,\n                feat_transform_cfg=dict(\n                    conv_cfg=dict(type='Conv2d'), act_cfg=None),\n                kernel_updator_cfg=dict(\n                    type='KernelUpdator',\n                    in_channels=256,\n                    feat_channels=256,\n                    out_channels=256,\n                    input_feat_shape=3,\n                    act_cfg=dict(type='ReLU', inplace=True),\n                    norm_cfg=dict(type='LN')),\n                loss_mask=dict(\n                    type='CrossEntropyLoss',\n                    use_sigmoid=True,\n      
              loss_weight=1.0),\n                loss_dice=dict(\n                    type='DiceLoss', loss_weight=4.0),\n                loss_cls=dict(\n                    type='FocalLoss',\n                    use_sigmoid=True,\n                    gamma=2.0,\n                    alpha=0.25,\n                    loss_weight=2.0),\n            ) for _ in range(num_stages)\n        ]\n    ),\n    track_train_cfg=dict(\n        assigner=dict(\n            type='MaskHungarianAssigner',\n            cls_cost=dict(type='FocalLossCost', weight=2.0),\n            dice_cost=dict(type='DiceCost', weight=4.0, pred_act=True),\n            mask_cost=dict(type='MaskCost', weight=1.0, pred_act=True)),\n        sampler=dict(type='MaskPseudoSampler'),),\n    bbox_roi_extractor=None\n)\n\n\nimg_norm_cfg = dict(\n    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=False)\n\ntrain_pipeline = [\n    dict(type='LoadMultiImagesDirect'),\n    dict(type='LoadMultiAnnotationsDirect', with_depth=False, divisor=-1, cherry_pick=True, cherry=[11, 13]),\n    dict(type='SeqResizeWithDepth', img_scale=(384, 1248), ratio_range=[0.5, 2.0], keep_ratio=True),\n    dict(type='SeqFlipWithDepth', flip_ratio=0.5),\n    dict(type='SeqRandomCropWithDepth', crop_size=(384, 1248), share_params=True),\n    dict(type='SeqNormalizeWithDepth', **img_norm_cfg),\n    dict(type='SeqPadWithDepth', size_divisor=32),\n    dict(\n        type='VideoCollect',\n        keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks', 'gt_semantic_seg', 'gt_instance_ids',]),\n    dict(type='ConcatVideoReferences'),\n    dict(type='SeqDefaultFormatBundle', ref_prefix='ref'),\n]\n\ntest_pipeline = [\n    dict(type='LoadImgDirect'),\n    dict(\n        type='MultiScaleFlipAug',\n        scale_factor=[1.0],\n        flip=False,\n        transforms=[\n            dict(type='RandomFlip'),\n            dict(type='Normalize', **img_norm_cfg),\n            dict(type='Pad', size_divisor=32),\n            
dict(type='ImageToTensor', keys=['img']),\n            dict(type='Collect',\n                 keys=['img', 'img_id', 'seq_id'],\n                 meta_keys=[\n                     'ori_shape', 'img_shape', 'pad_shape', 'scale_factor', 'flip',\n                     'flip_direction', 'img_norm_cfg', 'ori_filename'\n                 ]),\n        ])\n]\n\nrunner = dict(type='EpochBasedRunner', max_epochs=12)\n\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[9, 11])\n\ndata = dict(\n    samples_per_gpu=1,\n    workers_per_gpu=2,\n    train=dict(\n        type='RepeatDataset',\n        times=2,\n        dataset=dict(\n            split='train',\n            ref_seq_index=[-2, -1, 1, 2],\n            test_mode=False,\n            pipeline=train_pipeline\n        )),\n    test=dict(\n        ref_seq_index=None,\n        test_mode=True,\n        pipeline=test_pipeline,\n        split='val',\n    )\n)\n\nfind_unused_parameters=True"
  },
  {
    "path": "configs/det/video_knet_vipseg/video_knet_s3_r50_rpn_vipseg_mask_embed_link_ffn_joint_train.py",
    "content": "_base_ = [\n    '../_base_/schedules/schedule_1x.py',\n    '../_base_/default_runtime.py',\n    '../_base_/models/knet_vipseg_s3_r50_fpn.py',\n    '../_base_/datasets/vipseg_dvps.py',\n]\n\nnum_stages = 3\nconv_kernel_size = 1\nnum_thing_classes = 58\nnum_stuff_classes = 66\nnum_classes = num_stuff_classes + num_thing_classes\n\nmodel = dict(\n    type=\"VideoKNetQuansiEmbedFCJointTrain\",\n    # use cityscape style label distribution. # thing first , stuff second\n    cityscapes=False,\n    vipseg=True,\n    kitti_step=False,\n    link_previous=True,\n    mask_assign_stride=2,\n    num_thing_classes=num_thing_classes,\n    num_stuff_classes=num_stuff_classes,\n    ignore_label=255,\n    backbone=dict(\n            type='ResNet',\n            depth=50,\n            num_stages=4,\n            out_indices=(0, 1, 2, 3),\n            frozen_stages=1,\n            norm_cfg=dict(type='BN', requires_grad=True),\n            norm_eval=True\n    ),\n    rpn_head=dict(\n        num_thing_classes=num_thing_classes,\n        num_stuff_classes=num_stuff_classes,\n            loss_seg=dict(\n                    _delete_=True,\n                    type='CrossEntropyLoss',\n                    use_sigmoid=False,\n                    loss_weight=1.0),\n        feat_downsample_stride=4,\n    ),\n    # add track roi head\n    track_head=dict(\n        type='QuasiDenseMaskEmbedHeadGTMask',\n        num_convs=0,\n        num_fcs=2,\n        roi_feat_size=1,\n        in_channels=256,\n        fc_out_channels=256,\n        embed_channels=256,\n        norm_cfg=dict(type='GN', num_groups=32),\n        loss_track=dict(type='MultiPosCrossEntropyLoss', loss_weight=0.25),\n        loss_track_aux=dict(\n            type='L2Loss',\n            neg_pos_ub=3,\n            pos_margin=0,\n            neg_margin=0.1,\n            hard_mining=True,\n            loss_weight=1.0),\n    ),\n    # add tracker config\n    tracker=dict(\n        type='QuasiDenseEmbedTracker',\n        
init_score_thr=0.35,\n        obj_score_thr=0.3,\n        match_score_thr=0.5,\n        memo_tracklet_frames=5,\n        memo_backdrop_frames=1,\n        memo_momentum=0.8,\n        nms_conf_thr=0.5,\n        nms_backdrop_iou_thr=0.3,\n        nms_class_iou_thr=0.7,\n        with_cats=True,\n        match_metric='bisoftmax'\n    ),\n    # roi head\n    roi_head=dict(\n        type='VideoKernelIterHead',\n        num_stages=num_stages,\n        num_thing_classes=num_thing_classes,\n        num_stuff_classes=num_stuff_classes,\n        with_track=True,\n        merge_joint=True,\n        mask_head=[\n            dict(\n                type='VideoKernelUpdateHead',\n                num_classes=num_classes,\n                previous='placeholder',\n                previous_type=\"ffn\",\n                num_thing_classes=num_thing_classes,\n                num_stuff_classes=num_stuff_classes,\n                num_ffn_fcs=2,\n                num_heads=8,\n                num_cls_fcs=1,\n                num_mask_fcs=1,\n                feedforward_channels=2048,\n                in_channels=256,\n                out_channels=256,\n                dropout=0.0,\n                mask_thr=0.5,\n                conv_kernel_size=conv_kernel_size,\n                mask_upsample_stride=4,\n                ffn_act_cfg=dict(type='ReLU', inplace=True),\n                with_ffn=True,\n                feat_transform_cfg=dict(\n                    conv_cfg=dict(type='Conv2d'), act_cfg=None),\n                kernel_updator_cfg=dict(\n                    type='KernelUpdator',\n                    in_channels=256,\n                    feat_channels=256,\n                    out_channels=256,\n                    input_feat_shape=3,\n                    act_cfg=dict(type='ReLU', inplace=True),\n                    norm_cfg=dict(type='LN')),\n                loss_mask=dict(\n                    type='CrossEntropyLoss',\n                    use_sigmoid=True,\n                    
loss_weight=1.0),\n                loss_dice=dict(\n                    type='DiceLoss', loss_weight=4.0),\n                loss_cls=dict(\n                    type='FocalLoss',\n                    use_sigmoid=True,\n                    gamma=2.0,\n                    alpha=0.25,\n                    loss_weight=2.0),\n            ) for _ in range(num_stages)\n        ]\n    ),\n    track_train_cfg=dict(\n        assigner=dict(\n            type='MaskHungarianAssigner',\n            cls_cost=dict(type='FocalLossCost', weight=2.0),\n            dice_cost=dict(type='DiceCost', weight=4.0, pred_act=True),\n            mask_cost=dict(type='MaskCost', weight=1.0, pred_act=True)),\n        sampler=dict(type='MaskPseudoSampler'),),\n    bbox_roi_extractor=None\n)\n\n\nrunner = dict(type='EpochBasedRunner', max_epochs=12)\n\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[9, 11])\n\n\nfind_unused_parameters=True"
  },
  {
    "path": "configs/det/video_knet_vipseg/video_knet_s3_swin_b_rpn_vipseg_mask_embed_link_ffn_joint_train_8e.py",
    "content": "_base_ = [\n    '../_base_/schedules/schedule_1x.py',\n    '../_base_/default_runtime.py',\n    '../_base_/models/knet_vipseg_s3_r50_fpn.py',\n    '../_base_/datasets/vipseg_dvps.py',\n]\n\n\nnum_stages = 3\nconv_kernel_size = 1\nnum_thing_classes = 58\nnum_stuff_classes = 66\nnum_classes = num_stuff_classes + num_thing_classes\n\nmodel = dict(\n    type=\"VideoKNetQuansiEmbedFCJointTrain\",\n    # use cityscape style label distribution. # thing first , stuff second\n    cityscapes=False,\n    vipseg=True,\n    kitti_step=False,\n    link_previous=True,\n    mask_assign_stride=2,\n    num_thing_classes=num_thing_classes,\n    num_stuff_classes=num_stuff_classes,\n    ignore_label=255,\n    backbone=dict(\n        _delete_=True,\n        type='SwinTransformerDIY',\n        embed_dims=128,\n        depths=[2, 2, 18, 2],\n        num_heads=[4, 8, 16, 32],\n        window_size=7,\n        mlp_ratio=4.,\n        qkv_bias=True,\n        qk_scale=None,\n        drop_rate=0.,\n        attn_drop_rate=0.,\n        drop_path_rate=0.3,\n        use_abs_pos_embed=False,\n        patch_norm=True,\n        out_indices=(0, 1, 2, 3),\n        with_cp=False),\n    neck=dict(\n        in_channels=[128, 256, 512, 1024],\n    ),\n    rpn_head=dict(\n        num_thing_classes=num_thing_classes,\n        num_stuff_classes=num_stuff_classes,\n            loss_seg=dict(\n                    _delete_=True,\n                    type='CrossEntropyLoss',\n                    use_sigmoid=False,\n                    loss_weight=1.0),\n        feat_downsample_stride=4,\n    ),\n    # add track roi head\n    track_head=dict(\n        type='QuasiDenseMaskEmbedHeadGTMask',\n        num_convs=0,\n        num_fcs=2,\n        roi_feat_size=1,\n        in_channels=256,\n        fc_out_channels=256,\n        embed_channels=256,\n        norm_cfg=dict(type='GN', num_groups=32),\n        loss_track=dict(type='MultiPosCrossEntropyLoss', loss_weight=0.25),\n        loss_track_aux=dict(\n      
      type='L2Loss',\n            neg_pos_ub=3,\n            pos_margin=0,\n            neg_margin=0.1,\n            hard_mining=True,\n            loss_weight=1.0),\n    ),\n    # add tracker config\n    tracker=dict(\n        type='QuasiDenseEmbedTracker',\n        init_score_thr=0.35,\n        obj_score_thr=0.3,\n        match_score_thr=0.5,\n        memo_tracklet_frames=5,\n        memo_backdrop_frames=1,\n        memo_momentum=0.8,\n        nms_conf_thr=0.5,\n        nms_backdrop_iou_thr=0.3,\n        nms_class_iou_thr=0.7,\n        with_cats=True,\n        match_metric='bisoftmax'\n    ),\n    # roi head\n    roi_head=dict(\n        type='VideoKernelIterHead',\n        num_stages=num_stages,\n        num_thing_classes=num_thing_classes,\n        num_stuff_classes=num_stuff_classes,\n        with_track=True,\n        merge_joint=True,\n        mask_head=[\n            dict(\n                type='VideoKernelUpdateHead',\n                num_classes=num_classes,\n                previous='placeholder',\n                previous_type=\"ffn\",\n                num_thing_classes=num_thing_classes,\n                num_stuff_classes=num_stuff_classes,\n                num_ffn_fcs=2,\n                num_heads=8,\n                num_cls_fcs=1,\n                num_mask_fcs=1,\n                feedforward_channels=2048,\n                in_channels=256,\n                out_channels=256,\n                dropout=0.0,\n                mask_thr=0.5,\n                conv_kernel_size=conv_kernel_size,\n                mask_upsample_stride=4,\n                ffn_act_cfg=dict(type='ReLU', inplace=True),\n                with_ffn=True,\n                feat_transform_cfg=dict(\n                    conv_cfg=dict(type='Conv2d'), act_cfg=None),\n                kernel_updator_cfg=dict(\n                    type='KernelUpdator',\n                    in_channels=256,\n                    feat_channels=256,\n                    out_channels=256,\n                    
input_feat_shape=3,\n                    act_cfg=dict(type='ReLU', inplace=True),\n                    norm_cfg=dict(type='LN')),\n                loss_mask=dict(\n                    type='CrossEntropyLoss',\n                    use_sigmoid=True,\n                    loss_weight=1.0),\n                loss_dice=dict(\n                    type='DiceLoss', loss_weight=4.0),\n                loss_cls=dict(\n                    type='FocalLoss',\n                    use_sigmoid=True,\n                    gamma=2.0,\n                    alpha=0.25,\n                    loss_weight=2.0),\n            ) for _ in range(num_stages)\n        ]\n    ),\n    track_train_cfg=dict(\n        assigner=dict(\n            type='MaskHungarianAssigner',\n            cls_cost=dict(type='FocalLossCost', weight=2.0),\n            dice_cost=dict(type='DiceCost', weight=4.0, pred_act=True),\n            mask_cost=dict(type='MaskCost', weight=1.0, pred_act=True)),\n        sampler=dict(type='MaskPseudoSampler'),),\n    bbox_roi_extractor=None\n)\n\n\nrunner = dict(type='EpochBasedRunner', max_epochs=8)\n\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[7,])\n\n\nfind_unused_parameters=True"
  },
  {
    "path": "configs/video_knet_vis/_base_/datasets/coco_instance.py",
    "content": "dataset_type = 'CocoDataset'\ndata_root = 'data/coco/'\nimg_norm_cfg = dict(\n    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),\n    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),\n    dict(type='RandomFlip', flip_ratio=0.5),\n    dict(type='Normalize', **img_norm_cfg),\n    dict(type='Pad', size_divisor=32),\n    dict(type='DefaultFormatBundle'),\n    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),\n]\ntest_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='MultiScaleFlipAug',\n        img_scale=(1333, 800),\n        flip=False,\n        transforms=[\n            dict(type='Resize', keep_ratio=True),\n            dict(type='RandomFlip'),\n            dict(type='Normalize', **img_norm_cfg),\n            dict(type='Pad', size_divisor=32),\n            dict(type='ImageToTensor', keys=['img']),\n            dict(type='Collect', keys=['img']),\n        ])\n]\ndata = dict(\n    samples_per_gpu=2,\n    workers_per_gpu=2,\n    train=dict(\n        type=dataset_type,\n        ann_file=data_root + 'annotations/instances_train2017.json',\n        img_prefix=data_root + 'train2017/',\n        pipeline=train_pipeline),\n    val=dict(\n        type=dataset_type,\n        ann_file=data_root + 'annotations/instances_val2017.json',\n        img_prefix=data_root + 'val2017/',\n        pipeline=test_pipeline),\n    test=dict(\n        type=dataset_type,\n        ann_file=data_root + 'annotations/instances_val2017.json',\n        img_prefix=data_root + 'val2017/',\n        pipeline=test_pipeline))\n# we do not evaluate bbox because K-Net does not predict bounding boxes\nevaluation = dict(metric=['segm'])\n"
  },
  {
    "path": "configs/video_knet_vis/_base_/datasets/youtubevis_2019.py",
    "content": "# dataset settings\nimg_norm_cfg = dict(\n    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375],\n    to_rgb=True\n)\n\ntrain_pipeline = [\n    dict(type='LoadMultiImagesFromFile', to_float32=True),\n    dict(\n        type='SeqLoadAnnotations',\n        with_bbox=True,\n        with_mask=True,\n        with_track=True),\n    dict(\n        type='SeqResize',\n        multiscale_mode='value',\n        share_params=True,\n        img_scale=[(288,1e6), (320,1e6), (352,1e6), (392,1e6), (416,1e6), (448,1e6), (480,1e6), (512,1e6)],\n        keep_ratio=True\n    ),\n    dict(type='SeqRandomFlip', share_params=True, flip_ratio=0.5),\n    dict(type='SeqNormalize', **img_norm_cfg),\n    dict(type='SeqPad', size_divisor=32),\n    dict(\n        type='VideoCollect',\n        keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks', 'gt_instance_ids'],\n        reject_empty=True,\n        num_ref_imgs=5,\n    ),\n    dict(type='ConcatVideoReferences'),\n    dict(type='SeqDefaultFormatBundle', ref_prefix='ref'),\n]\n\ntest_pipeline = [\n    dict(type='LoadMultiImagesFromFile', to_float32=True),\n    dict(type='MultiScaleFlipAugVideo',\n         img_scale=(640, 360),\n         flip=False,\n         transforms=[\n             dict(type='SeqResize'),\n             dict(type='SeqNormalize', **img_norm_cfg),\n             dict(type='SeqPad', size_divisor=32),\n             dict(\n                 type='VideoCollect',\n                 keys=['img'],\n                 reject_empty=False,\n                 num_ref_imgs=0,  # 0 means do not apply check\n             ),\n             dict(type='ConcatVideoReferences'),\n             dict(type='SeqDefaultFormatBundle', ref_prefix='ref'),\n         ])\n]\n\ndataset_type = 'YouTubeVISDataset'\ndata_root = 'data/youtube_vis_2019/'\ndataset_version = '2019'\ndata = dict(\n    samples_per_gpu=2,\n    workers_per_gpu=2,\n    train=dict(\n        type=dataset_type,\n        dataset_version=dataset_version,\n        
ann_file=data_root + 'annotations/youtube_vis_2019_train.json',\n        img_prefix=data_root + 'train/JPEGImages',\n        ref_img_sampler=dict(\n            num_ref_imgs=5,\n            frame_range=[-2, 2],\n            filter_key_img=False,\n            method='uniform'),\n        pipeline=train_pipeline\n    ),\n    val=dict(\n        type=dataset_type,\n        dataset_version=dataset_version,\n        ann_file=data_root + 'annotations/youtube_vis_2019_valid.json',\n        img_prefix=data_root + 'valid/JPEGImages',\n        ref_img_sampler=None,\n        load_all_frames=True,\n        pipeline=test_pipeline\n    ),\n    test=dict(\n        type=dataset_type,\n        dataset_version=dataset_version,\n        ann_file=data_root + 'annotations/youtube_vis_2019_valid.json',\n        img_prefix=data_root + 'valid/JPEGImages',\n        ref_img_sampler=None,\n        load_all_frames=True,\n        pipeline=test_pipeline\n    )\n)\n"
  },
  {
    "path": "configs/video_knet_vis/_base_/default_runtime.py",
    "content": "checkpoint_config = dict(interval=1)\nlog_config = dict(\n    interval=50,\n    hooks=[\n        dict(type='TextLoggerHook'),\n    ]\n)\n# custom_hooks = [dict(type='NumClassCheckHook')]\n\ndist_params = dict(backend='nccl')\nlog_level = 'INFO'\nload_from = None\nresume_from = None\nworkflow = [('train', 1)]\n\nwork_dir = 'logger/blackhole'\n"
  },
  {
    "path": "configs/video_knet_vis/_base_/models/knet_track_r50.py",
    "content": "num_stages = 3\nnum_proposals = 100\nconv_kernel_size = 1\nmodel = dict(\n    type='KNetTrack',\n    backbone=dict(\n        type='ResNet',\n        depth=50,\n        num_stages=4,\n        out_indices=(0, 1, 2, 3),\n        frozen_stages=1,\n        norm_cfg=dict(type='BN', requires_grad=True),\n        norm_eval=True,\n        style='pytorch',\n        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),\n    neck=dict(\n        type='FPN',\n        in_channels=[256, 512, 1024, 2048],\n        out_channels=256,\n        start_level=0,\n        add_extra_convs='on_input',\n        num_outs=4),\n    rpn_head=dict(\n        type='ConvKernelHeadVideo',\n        conv_kernel_size=conv_kernel_size,\n        feat_downsample_stride=2,\n        feat_refine_stride=1,\n        feat_refine=False,\n        use_binary=True,\n        num_loc_convs=1,\n        num_seg_convs=1,\n        conv_normal_init=True,\n        localization_fpn=dict(\n            type='SemanticFPNWrapper',\n            in_channels=256,\n            feat_channels=256,\n            out_channels=256,\n            start_level=0,\n            end_level=3,\n            upsample_times=2,\n            positional_encoding=dict(\n                type='SinePositionalEncoding', num_feats=128, normalize=True),\n            cat_coors=False,\n            cat_coors_level=3,\n            fuse_by_cat=False,\n            return_list=False,\n            num_aux_convs=1,\n            norm_cfg=dict(type='GN', num_groups=32, requires_grad=True)\n        ),\n        num_proposals=num_proposals,\n        proposal_feats_with_obj=True,\n        xavier_init_kernel=False,\n        kernel_init_std=1,\n        num_cls_fcs=1,\n        in_channels=256,\n        num_classes=40,\n        feat_transform_cfg=None,\n        loss_seg=dict(\n            type='FocalLoss',\n            use_sigmoid=True,\n            gamma=2.0,\n            alpha=0.25,\n            loss_weight=1.0),\n        loss_mask=dict(\n      
      type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),\n        loss_dice=dict(type='DiceLoss', loss_weight=4.0)),\n    roi_head=dict(\n        type='KernelIterHeadVideo',\n        num_stages=num_stages,\n        stage_loss_weights=[1] * num_stages,\n        proposal_feature_channel=256,\n        num_thing_classes=40,\n        num_stuff_classes=0,\n        mask_head=[\n            dict(\n                type='KernelUpdateHead',\n                num_classes=40,\n                num_thing_classes=40,\n                num_stuff_classes=0,\n                num_ffn_fcs=2,\n                num_heads=8,\n                num_cls_fcs=1,\n                num_mask_fcs=1,\n                feedforward_channels=2048,\n                in_channels=256,\n                out_channels=256,\n                dropout=0.0,\n                mask_thr=0.5,\n                conv_kernel_size=conv_kernel_size,\n                mask_upsample_stride=2,\n                ffn_act_cfg=dict(type='ReLU', inplace=True),\n                with_ffn=True,\n                feat_transform_cfg=dict(\n                    conv_cfg=dict(type='Conv2d'), act_cfg=None),\n                kernel_updator_cfg=dict(\n                    type='KernelUpdator',\n                    in_channels=256,\n                    feat_channels=256,\n                    out_channels=256,\n                    input_feat_shape=3,\n                    act_cfg=dict(type='ReLU', inplace=True),\n                    norm_cfg=dict(type='LN')),\n                loss_mask=dict(\n                    type='CrossEntropyLoss',\n                    use_sigmoid=True,\n                    loss_weight=1.0),\n                loss_dice=dict(\n                    type='DiceLoss', loss_weight=4.0),\n                loss_cls=dict(\n                    type='FocalLoss',\n                    use_sigmoid=True,\n                    gamma=2.0,\n                    alpha=0.25,\n                    loss_weight=2.0)\n            ) for _ in 
range(num_stages)\n        ]),\n    tracker=dict(\n        type=\"KernelFrameIterHeadVideo\",\n        num_proposals=num_proposals,\n        num_stages=3,\n        assign_stages=2,\n        proposal_feature_channel=256,\n        stage_loss_weights=(1., 1., 1.),\n        num_thing_classes=40,\n        num_stuff_classes=0,\n        mask_head=dict(\n            type='KernelUpdateHeadVideo',\n            num_proposals=num_proposals,\n            num_classes=40,\n            num_thing_classes=40,\n            num_stuff_classes=0,\n            num_ffn_fcs=2,\n            num_heads=8,\n            num_cls_fcs=1,\n            num_mask_fcs=1,\n            feedforward_channels=2048,\n            in_channels=256,\n            out_channels=256,\n            dropout=0.0,\n            mask_thr=0.5,\n            conv_kernel_size=conv_kernel_size,\n            mask_upsample_stride=2,\n            ffn_act_cfg=dict(type='ReLU', inplace=True),\n            with_ffn=True,\n            feat_transform_cfg=dict(\n                conv_cfg=dict(type='Conv2d'), act_cfg=None),\n            kernel_updator_cfg=dict(\n                type='KernelUpdator',\n                in_channels=256,\n                feat_channels=256,\n                out_channels=256,\n                input_feat_shape=3,\n                act_cfg=dict(type='ReLU', inplace=True),\n                norm_cfg=dict(type='LN')),\n            loss_mask=dict(\n                type='CrossEntropyLoss',\n                use_sigmoid=True,\n                loss_weight=1.0),\n            loss_dice=dict(\n                type='DiceLoss', loss_weight=4.0),\n            loss_cls=dict(\n                type='FocalLoss',\n                use_sigmoid=True,\n                gamma=2.0,\n                alpha=0.25,\n                loss_weight=2.0)\n        ),\n\n    ),\n    # training and testing settings\n    train_cfg=dict(\n        rpn=dict(\n            assigner=dict(\n                type='MaskHungarianAssigner',\n                
cls_cost=dict(type='FocalLossCost', weight=2.0),\n                dice_cost=dict(type='DiceCost', weight=4.0, pred_act=True),\n                mask_cost=dict(type='MaskCost', weight=1.0, pred_act=True)\n            ),\n            sampler=dict(type='MaskPseudoSampler'),\n            pos_weight=1),\n        rcnn=[\n            dict(\n                assigner=dict(\n                    type='MaskHungarianAssigner',\n                    cls_cost=dict(type='FocalLossCost', weight=2.0),\n                    dice_cost=dict(type='DiceCost', weight=4.0, pred_act=True),\n                    mask_cost=dict(type='MaskCost', weight=1.0,\n                                   pred_act=True)\n                ),\n                sampler=dict(type='MaskPseudoSampler'),\n                pos_weight=1) for _ in range(num_stages)\n        ],\n        tracker=dict(\n            assigner=dict(\n                type='MaskHungarianAssignerVideo',\n                cls_cost=dict(type='FocalLossCost', weight=2.0),\n                dice_cost=dict(type='DiceCost', weight=4.0, pred_act=True),\n                mask_cost=dict(type='MaskCost', weight=1.0,\n                               pred_act=True)\n            ),\n            sampler=dict(type='MaskPseudoSampler'),\n            pos_weight=1)\n    ),\n    test_cfg=dict(\n        rpn=None,\n        rcnn=dict(\n            max_per_img=10,\n            mask_thr=0.5,\n            merge_stuff_thing=dict(\n                iou_thr=0.5, stuff_max_area=4096, instance_score_thr=0.3\n            )\n        ),\n        tracker=dict(\n            max_per_img=10,\n            mask_thr=0.5,\n            merge_stuff_thing=dict(\n                iou_thr=0.5, stuff_max_area=4096, instance_score_thr=0.3\n            ),\n        ),\n    )\n)\n\ncustom_imports = dict(\n    imports=[\n        'knet_vis.det.knet',\n        'knet_vis.det.kernel_head',\n        'knet_vis.det.kernel_iter_head',\n        'knet_vis.det.kernel_update_head',\n        
'knet_vis.det.semantic_fpn_wrapper',\n        'knet_vis.kernel_updator',\n        'knet_vis.det.mask_hungarian_assigner',\n        'knet_vis.det.mask_pseudo_sampler',\n        'knet_vis.tracker.track',\n        'knet_vis.tracker.kernel_head',\n        'knet_vis.tracker.kernel_iter_head',\n        'knet_vis.tracker.kernel_frame_iter_head',\n        'knet_vis.tracker.mask_hungarian_assigner',\n        'knet_vis.tracker.kernel_update_head',\n        'swin.swin_transformer',\n        'mmtrack.datasets.youtube_vis_dataset',\n        'mmtrack.pipelines',\n    ],\n    allow_failed_imports=False\n)\n"
  },
  {
    "path": "configs/video_knet_vis/_base_/models/knet_track_r50_deformablefpn.py",
    "content": "num_stages = 3\nnum_proposals = 100\nconv_kernel_size = 1\nmodel = dict(\n    type='KNetTrack',\n    backbone=dict(\n        type='ResNet',\n        depth=50,\n        num_stages=4,\n        out_indices=(0, 1, 2, 3),\n        frozen_stages=1,\n        norm_cfg=dict(type='BN', requires_grad=True),\n        norm_eval=True,\n        style='pytorch',\n        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),\n    neck=dict(\n        type='MSDeformAttnPixelDecoder',\n        num_outs=3,\n        norm_cfg=dict(type='GN', num_groups=32),\n        act_cfg=dict(type='ReLU'),\n        return_one_list=True,\n        encoder=dict(\n            type='DetrTransformerEncoder',\n            num_layers=6,\n            transformerlayers=dict(\n                type='BaseTransformerLayer',\n                attn_cfgs=dict(\n                    type='MultiScaleDeformableAttention',\n                    embed_dims=256,\n                    num_heads=8,\n                    num_levels=3,\n                    num_points=4,\n                    im2col_step=64,\n                    dropout=0.0,\n                    batch_first=False,\n                    norm_cfg=None,\n                    init_cfg=None),\n                ffn_cfgs=dict(\n                    type='FFN',\n                    embed_dims=256,\n                    feedforward_channels=1024,\n                    num_fcs=2,\n                    ffn_drop=0.0,\n                    act_cfg=dict(type='ReLU', inplace=True)),\n                operation_order=('self_attn', 'norm', 'ffn', 'norm')),\n            init_cfg=None),\n        positional_encoding=dict(\n            type='SinePositionalEncoding', num_feats=128, normalize=True),\n        init_cfg=None),\n    rpn_head=dict(\n        type='ConvKernelHeadVideo',\n        conv_kernel_size=conv_kernel_size,\n        feat_downsample_stride=2,\n        feat_refine_stride=1,\n        feat_refine=False,\n        use_binary=True,\n        
num_loc_convs=1,\n        num_seg_convs=1,\n        conv_normal_init=True,\n        localization_fpn=dict(\n            type='SemanticFPNWrapper',\n            in_channels=256,\n            feat_channels=256,\n            out_channels=256,\n            start_level=0,\n            end_level=3,\n            upsample_times=2,\n            positional_encoding=dict(\n                type='SinePositionalEncoding', num_feats=128, normalize=True),\n            cat_coors=False,\n            cat_coors_level=3,\n            fuse_by_cat=False,\n            return_list=False,\n            num_aux_convs=1,\n            norm_cfg=dict(type='GN', num_groups=32, requires_grad=True)\n        ),\n        num_proposals=num_proposals,\n        proposal_feats_with_obj=True,\n        xavier_init_kernel=False,\n        kernel_init_std=1,\n        num_cls_fcs=1,\n        in_channels=256,\n        num_classes=40,\n        feat_transform_cfg=None,\n        loss_seg=dict(\n            type='FocalLoss',\n            use_sigmoid=True,\n            gamma=2.0,\n            alpha=0.25,\n            loss_weight=1.0),\n        loss_mask=dict(\n            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),\n        loss_dice=dict(type='DiceLoss', loss_weight=4.0)),\n    roi_head=dict(\n        type='KernelIterHeadVideo',\n        num_stages=num_stages,\n        stage_loss_weights=[1] * num_stages,\n        proposal_feature_channel=256,\n        num_thing_classes=40,\n        num_stuff_classes=0,\n        mask_head=[\n            dict(\n                type='KernelUpdateHead',\n                num_classes=40,\n                num_thing_classes=40,\n                num_stuff_classes=0,\n                num_ffn_fcs=2,\n                num_heads=8,\n                num_cls_fcs=1,\n                num_mask_fcs=1,\n                feedforward_channels=2048,\n                in_channels=256,\n                out_channels=256,\n                dropout=0.0,\n                mask_thr=0.5,\n            
    conv_kernel_size=conv_kernel_size,\n                mask_upsample_stride=2,\n                ffn_act_cfg=dict(type='ReLU', inplace=True),\n                with_ffn=True,\n                feat_transform_cfg=dict(\n                    conv_cfg=dict(type='Conv2d'), act_cfg=None),\n                kernel_updator_cfg=dict(\n                    type='KernelUpdator',\n                    in_channels=256,\n                    feat_channels=256,\n                    out_channels=256,\n                    input_feat_shape=3,\n                    act_cfg=dict(type='ReLU', inplace=True),\n                    norm_cfg=dict(type='LN')),\n                loss_mask=dict(\n                    type='CrossEntropyLoss',\n                    use_sigmoid=True,\n                    loss_weight=1.0),\n                loss_dice=dict(\n                    type='DiceLoss', loss_weight=4.0),\n                loss_cls=dict(\n                    type='FocalLoss',\n                    use_sigmoid=True,\n                    gamma=2.0,\n                    alpha=0.25,\n                    loss_weight=2.0)\n            ) for _ in range(num_stages)\n        ]),\n    tracker=dict(\n        type=\"KernelFrameIterHeadVideo\",\n        num_proposals=num_proposals,\n        num_stages=3,\n        assign_stages=2,\n        proposal_feature_channel=256,\n        stage_loss_weights=(1., 1., 1.),\n        num_thing_classes=40,\n        num_stuff_classes=0,\n        mask_head=dict(\n            type='KernelUpdateHeadVideo',\n            num_proposals=num_proposals,\n            num_classes=40,\n            num_thing_classes=40,\n            num_stuff_classes=0,\n            num_ffn_fcs=2,\n            num_heads=8,\n            num_cls_fcs=1,\n            num_mask_fcs=1,\n            feedforward_channels=2048,\n            in_channels=256,\n            out_channels=256,\n            dropout=0.0,\n            mask_thr=0.5,\n            conv_kernel_size=conv_kernel_size,\n            
mask_upsample_stride=2,\n            ffn_act_cfg=dict(type='ReLU', inplace=True),\n            with_ffn=True,\n            feat_transform_cfg=dict(\n                conv_cfg=dict(type='Conv2d'), act_cfg=None),\n            kernel_updator_cfg=dict(\n                type='KernelUpdator',\n                in_channels=256,\n                feat_channels=256,\n                out_channels=256,\n                input_feat_shape=3,\n                act_cfg=dict(type='ReLU', inplace=True),\n                norm_cfg=dict(type='LN')),\n            loss_mask=dict(\n                type='CrossEntropyLoss',\n                use_sigmoid=True,\n                loss_weight=1.0),\n            loss_dice=dict(\n                type='DiceLoss', loss_weight=4.0),\n            loss_cls=dict(\n                type='FocalLoss',\n                use_sigmoid=True,\n                gamma=2.0,\n                alpha=0.25,\n                loss_weight=2.0)\n        ),\n\n    ),\n    # training and testing settings\n    train_cfg=dict(\n        rpn=dict(\n            assigner=dict(\n                type='MaskHungarianAssigner',\n                cls_cost=dict(type='FocalLossCost', weight=2.0),\n                dice_cost=dict(type='DiceCost', weight=4.0, pred_act=True),\n                mask_cost=dict(type='MaskCost', weight=1.0, pred_act=True)\n            ),\n            sampler=dict(type='MaskPseudoSampler'),\n            pos_weight=1),\n        rcnn=[\n            dict(\n                assigner=dict(\n                    type='MaskHungarianAssigner',\n                    cls_cost=dict(type='FocalLossCost', weight=2.0),\n                    dice_cost=dict(type='DiceCost', weight=4.0, pred_act=True),\n                    mask_cost=dict(type='MaskCost', weight=1.0,\n                                   pred_act=True)\n                ),\n                sampler=dict(type='MaskPseudoSampler'),\n                pos_weight=1) for _ in range(num_stages)\n        ],\n        tracker=dict(\n            
assigner=dict(\n                type='MaskHungarianAssignerVideo',\n                cls_cost=dict(type='FocalLossCost', weight=2.0),\n                dice_cost=dict(type='DiceCost', weight=4.0, pred_act=True),\n                mask_cost=dict(type='MaskCost', weight=1.0,\n                               pred_act=True)\n            ),\n            sampler=dict(type='MaskPseudoSampler'),\n            pos_weight=1)\n    ),\n    test_cfg=dict(\n        rpn=None,\n        rcnn=dict(\n            max_per_img=10,\n            mask_thr=0.5,\n            merge_stuff_thing=dict(\n                iou_thr=0.5, stuff_max_area=4096, instance_score_thr=0.3\n            )\n        ),\n        tracker=dict(\n            max_per_img=10,\n            mask_thr=0.5,\n            merge_stuff_thing=dict(\n                iou_thr=0.5, stuff_max_area=4096, instance_score_thr=0.3\n            ),\n        ),\n    )\n)\n\ncustom_imports = dict(\n    imports=[\n        'knet_vis.det.knet',\n        'knet_vis.det.kernel_head',\n        'knet_vis.det.kernel_iter_head',\n        'knet_vis.det.kernel_update_head',\n        'knet_vis.det.semantic_fpn_wrapper',\n        'knet_vis.kernel_updator',\n        'knet.det.msdeformattn_decoder',\n        'knet_vis.det.mask_hungarian_assigner',\n        'knet_vis.det.mask_pseudo_sampler',\n        'knet_vis.tracker.track',\n        'knet_vis.tracker.kernel_head',\n        'knet_vis.tracker.kernel_iter_head',\n        'knet_vis.tracker.kernel_frame_iter_head',\n        'knet_vis.tracker.mask_hungarian_assigner',\n        'knet_vis.tracker.kernel_update_head',\n        'swin.swin_transformer',\n        'mmtrack.datasets.youtube_vis_dataset',\n        'mmtrack.pipelines',\n    ],\n    allow_failed_imports=False\n)\n"
  },
  {
    "path": "configs/video_knet_vis/_base_/schedules/schedule_0.75x.py",
    "content": "# optimizer\noptimizer = dict(\n    type='AdamW',\n    lr=0.0001,\n    weight_decay=0.05,\n    paramwise_cfg=dict(\n        custom_keys={\n            'backbone': dict(lr_mult=0.25)\n        }\n    )\n)\noptimizer_config = dict(grad_clip=dict(max_norm=1, norm_type=2))\n# learning policy\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[5, 7]\n)\nrunner = dict(type='EpochBasedRunner', max_epochs=8)\n"
  },
  {
    "path": "configs/video_knet_vis/_base_/schedules/schedule_1x.py",
    "content": "# optimizer\noptimizer = dict(\n    type='AdamW',\n    lr=0.0001,\n    weight_decay=0.05,\n    paramwise_cfg=dict(\n        custom_keys={\n            'backbone': dict(lr_mult=0.25)\n        }\n    )\n)\noptimizer_config = dict(grad_clip=dict(max_norm=1, norm_type=2))\n# learning policy\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[8, 11]\n)\nrunner = dict(type='EpochBasedRunner', max_epochs=12)\n"
  },
  {
    "path": "configs/video_knet_vis/_base_/schedules/schedule_8e.py",
    "content": "# optimizer\noptimizer = dict(\n    type='AdamW',\n    lr=0.0001,\n    weight_decay=0.05,\n    paramwise_cfg=dict(\n        custom_keys={\n            'backbone': dict(lr_mult=0.25)\n        }\n    )\n)\noptimizer_config = dict(grad_clip=dict(max_norm=1, norm_type=2))\n# learning policy\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[7, ]\n)\nrunner = dict(type='EpochBasedRunner', max_epochs=8)\n"
  },
  {
    "path": "configs/video_knet_vis/common/mstrain_3x_coco_instance.py",
    "content": "_base_ = '../_base_/default_runtime.py'\n# dataset settings\ndataset_type = 'CocoDataset'\ndata_root = 'data/coco/'\nimg_norm_cfg = dict(\n    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)\n\n# In mstrain 3x config, img_scale=[(1333, 640), (1333, 800)],\n# multiscale_mode='range'\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),\n    dict(\n        type='Resize',\n        img_scale=[(1333, 640), (1333, 800)],\n        multiscale_mode='range',\n        keep_ratio=True),\n    dict(type='RandomFlip', flip_ratio=0.5),\n    dict(type='Normalize', **img_norm_cfg),\n    dict(type='Pad', size_divisor=32),\n    dict(type='DefaultFormatBundle'),\n    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),\n]\ntest_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='MultiScaleFlipAug',\n        img_scale=(1333, 800),\n        flip=False,\n        transforms=[\n            dict(type='Resize', keep_ratio=True),\n            dict(type='RandomFlip'),\n            dict(type='Normalize', **img_norm_cfg),\n            dict(type='Pad', size_divisor=32),\n            dict(type='ImageToTensor', keys=['img']),\n            dict(type='Collect', keys=['img']),\n        ])\n]\n\n# Use RepeatDataset to speed up training\ndata = dict(\n    samples_per_gpu=2,\n    workers_per_gpu=2,\n    train=dict(\n        type='RepeatDataset',\n        times=3,\n        dataset=dict(\n            type=dataset_type,\n            ann_file=data_root + 'annotations/instances_train2017.json',\n            img_prefix=data_root + 'train2017/',\n            pipeline=train_pipeline)),\n    val=dict(\n        type=dataset_type,\n        ann_file=data_root + 'annotations/instances_val2017.json',\n        img_prefix=data_root + 'val2017/',\n        pipeline=test_pipeline),\n    test=dict(\n        type=dataset_type,\n        ann_file=data_root + 
'annotations/instances_val2017.json',\n        img_prefix=data_root + 'val2017/',\n        pipeline=test_pipeline))\nevaluation = dict(interval=1, metric=['segm'])\n\n# optimizer\n# this is different from the original 1x schedule that uses SGD\noptimizer = dict(\n    type='AdamW',\n    lr=0.0001,\n    weight_decay=0.05,\n    paramwise_cfg=dict(custom_keys={'backbone': dict(lr_mult=0.25)}))\noptimizer_config = dict(grad_clip=dict(max_norm=1, norm_type=2))\n\n# learning policy\n# Experiments show that using step=[9, 11] has higher performance\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[9, 11])\nrunner = dict(type='EpochBasedRunner', max_epochs=12)\n"
  },
  {
    "path": "configs/video_knet_vis/video_knet_vis/knet_track_r50_1x_youtubevis.py",
    "content": "_base_ = [\n    '../_base_/schedules/schedule_1x.py',\n    '../_base_/default_runtime.py',\n    '../_base_/models/knet_track_r50.py',\n    '../_base_/datasets/youtubevis_2019.py',\n]"
  },
  {
    "path": "configs/video_knet_vis/video_knet_vis/knet_track_r50_deformable_fpn_1x_youtubevis.py",
    "content": "_base_ = [\n    '../_base_/schedules/schedule_1x.py',\n    '../_base_/default_runtime.py',\n    '../_base_/models/knet_track_r50_deformablefpn.py',\n    '../_base_/datasets/youtubevis_2019.py',\n]\n\n\ndata = dict(\n    samples_per_gpu=1,\n    workers_per_gpu=2,)\n"
  },
  {
    "path": "configs/video_knet_vis/video_knet_vis/knet_track_swinb_1x_youtubevis_8e.py",
    "content": "_base_ = [\n    '../_base_/schedules/schedule_8e.py',\n    '../_base_/default_runtime.py',\n    '../_base_/models/knet_track_r50.py',\n    '../_base_/datasets/youtubevis_2019.py',\n]\n\nmodel = dict(\n    backbone=dict(\n        _delete_=True,\n        type='SwinTransformerDIY',\n        embed_dims=128,\n        depths=[2, 2, 18, 2],\n        num_heads=[4, 8, 16, 32],\n        window_size=7,\n        mlp_ratio=4,\n        qkv_bias=True,\n        qk_scale=None,\n        drop_rate=0.,\n        attn_drop_rate=0.,\n        drop_path_rate=0.3,\n        use_abs_pos_embed=False,\n        patch_norm=True,\n        out_indices=(0, 1, 2, 3),\n        with_cp=False\n    ),\n    neck=dict(in_channels=[128, 256, 512, 1024]),\n)\n\ndata = dict(\n    samples_per_gpu=1,\n    workers_per_gpu=2,\n)"
  },
  {
    "path": "configs/video_knet_vis/video_knet_vis/knet_track_swinb_deformable_1x_youtubevis.py",
    "content": "_base_ = [\n    '../_base_/schedules/schedule_1x.py',\n    '../_base_/default_runtime.py',\n    '../_base_/models/knet_track_r50_deformablefpn.py',\n    '../_base_/datasets/youtubevis_2019.py',\n]\n\nmodel = dict(\n    backbone=dict(\n        _delete_=True,\n        type='SwinTransformerDIY',\n        embed_dims=128,\n        depths=[2, 2, 18, 2],\n        num_heads=[4, 8, 16, 32],\n        window_size=7,\n        mlp_ratio=4,\n        qkv_bias=True,\n        qk_scale=None,\n        drop_rate=0.,\n        attn_drop_rate=0.,\n        drop_path_rate=0.3,\n        use_abs_pos_embed=False,\n        patch_norm=True,\n        out_indices=(0, 1, 2, 3),\n        with_cp=True\n    ),\n    neck=dict(in_channels=[128, 256, 512, 1024]),\n)\n\ndataset_type = 'YouTubeVISDataset'\ndata_root = 'data/youtube_vis_2019/'\ndataset_version = '2019'\n\nimg_norm_cfg = dict(\n    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375],\n    to_rgb=True\n)\n\ntrain_pipeline = [\n    dict(type='LoadMultiImagesFromFile', to_float32=True),\n    dict(\n        type='SeqLoadAnnotations',\n        with_bbox=True,\n        with_mask=True,\n        with_track=True),\n    dict(\n        type='SeqResize',\n        multiscale_mode='value',\n        share_params=True,\n        img_scale=[(288,1e6), (320,1e6), (352,1e6), (392,1e6), (416,1e6), (448,1e6), (480,1e6), (512,1e6)],\n        keep_ratio=True\n    ),\n    dict(type='SeqRandomFlip', share_params=True, flip_ratio=0.5),\n    dict(type='SeqNormalize', **img_norm_cfg),\n    dict(type='SeqPad', size_divisor=32),\n    dict(\n        type='VideoCollect',\n        keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks', 'gt_instance_ids'],\n        reject_empty=True,\n        num_ref_imgs=5,\n    ),\n    dict(type='ConcatVideoReferences'),\n    dict(type='SeqDefaultFormatBundle', ref_prefix='ref'),\n]\n\ndata = dict(\n    samples_per_gpu=1,\n    workers_per_gpu=2,\n    train=dict(\n        type='RepeatDataset',\n        times=1,\n        
dataset=dict(\n        type=dataset_type,\n        dataset_version=dataset_version,\n        ann_file=data_root + 'annotations/youtube_vis_2019_train.json',\n        img_prefix=data_root + 'train/JPEGImages',\n        ref_img_sampler=dict(\n            num_ref_imgs=5,\n            frame_range=[-2, 2],\n            filter_key_img=False,\n            method='uniform'),\n        pipeline=train_pipeline\n    )),\n)"
  },
  {
    "path": "external/cityscape_panoptic.py",
    "content": "import contextlib\nimport io\nimport itertools\nimport os\nimport glob\nimport tempfile\nimport logging\nimport os.path as osp\nfrom collections import OrderedDict\n\nimport pycocotools.mask as maskUtils\n\nimport mmcv\nimport numpy as np\nfrom mmcv.utils import print_log\nfrom mmdet.datasets.builder import DATASETS\nfrom mmdet.datasets.coco import CocoDataset\nfrom mmdet.datasets.api_wrappers import COCO, COCOeval\nfrom terminaltables import AsciiTable\nfrom external.coco_panoptic import parse_pq_results, _print_panoptic_results\n\n\n@DATASETS.register_module()\nclass CityscapesPanopticDataset(CocoDataset):\n\n    CLASSES = ('person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',\n               'bicycle')\n\n    def load_annotations(self, ann_file):\n        \"\"\"Load annotation from COCO style annotation file.\n\n        Args:\n            ann_file (str): Path of annotation file.\n\n        Returns:\n            list[dict]: Annotation info from COCO api.\n        \"\"\"\n\n        self.coco = COCO(ann_file['ins_ann'])\n        self.cat_ids = self.coco.get_cat_ids(cat_names=self.CLASSES)\n        self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}\n        self.img_ids = sorted(self.coco.get_img_ids())\n\n        self.panoptic_anns = mmcv.load(ann_file['panoptic_ann'])\n\n        self.stuff_ids = [\n            k['id'] for k in self.panoptic_anns['categories']\n            if k['isthing'] == 0\n        ]\n\n        self.thing_ids = [\n            k['id'] for k in self.panoptic_anns['categories']\n            if k['isthing'] == 1\n        ]\n\n        assert self.thing_ids == self.cat_ids\n\n        self.seg2stuff_ids = {\n            i + 1: stuff_id\n            for i, stuff_id in enumerate(self.stuff_ids)\n        }\n\n        self.seg2stuff_ids.update({0: 0})\n\n        self.ins2thing_ids = {\n            i: thing_id\n            for i, thing_id in enumerate(self.thing_ids)\n        }\n\n\n        data_infos = []\n    
    total_ann_ids = []\n        for i in self.img_ids:\n            info = self.coco.load_imgs([i])[0]\n            info['filename'] = info['file_name']\n            data_infos.append(info)\n            ann_ids = self.coco.get_ann_ids(img_ids=[i])\n            total_ann_ids.extend(ann_ids)\n        assert len(set(total_ann_ids)) == len(\n            total_ann_ids), f\"Annotation ids in '{ann_file}' are not unique!\"\n        return data_infos\n\n    def _filter_imgs(self, min_size=32):\n        \"\"\"Filter images too small or without ground truths.\"\"\"\n        valid_inds = []\n        # obtain images that contain annotation\n        ids_with_ann = set(_['image_id'] for _ in self.coco.anns.values())\n        # obtain images that contain annotations of the required categories\n        ids_in_cat = set()\n        for i, class_id in enumerate(self.cat_ids):\n            ids_in_cat |= set(self.coco.cat_img_map[class_id])\n        # merge the image id sets of the two conditions and use the merged set\n        # to filter out images if self.filter_empty_gt=True\n        ids_in_cat &= ids_with_ann\n\n        valid_img_ids = []\n        for i, img_info in enumerate(self.data_infos):\n            img_id = img_info['id']\n            ann_ids = self.coco.getAnnIds(imgIds=[img_id])\n            ann_info = self.coco.loadAnns(ann_ids)\n            all_iscrowd = all([_['iscrowd'] for _ in ann_info])\n            if self.filter_empty_gt and (self.img_ids[i] not in ids_in_cat\n                                         or all_iscrowd):\n                continue\n            if min(img_info['width'], img_info['height']) >= min_size:\n                valid_inds.append(i)\n                valid_img_ids.append(img_id)\n        self.img_ids = valid_img_ids\n        return valid_inds\n\n    def _parse_ann_info(self, img_info, ann_info):\n        \"\"\"Parse bbox and mask annotation.\n\n        Args:\n            img_info (dict): Image info of an image.\n            ann_info 
(list[dict]): Annotation info of an image.\n\n        Returns:\n            dict: A dict containing the following keys: bboxes, \\\n                bboxes_ignore, labels, masks, seg_map. \\\n                \"masks\" are already decoded into binary masks.\n        \"\"\"\n        gt_bboxes = []\n        gt_labels = []\n        gt_bboxes_ignore = []\n        gt_masks_ann = []\n\n        for i, ann in enumerate(ann_info):\n            if ann.get('ignore', False):\n                continue\n            x1, y1, w, h = ann['bbox']\n            if ann['area'] <= 0 or w < 1 or h < 1:\n                continue\n            if ann['category_id'] not in self.cat_ids:\n                continue\n            bbox = [x1, y1, x1 + w, y1 + h]\n            if ann.get('iscrowd', False):\n                gt_bboxes_ignore.append(bbox)\n            else:\n                gt_bboxes.append(bbox)\n                gt_labels.append(self.cat2label[ann['category_id']])\n                gt_masks_ann.append(ann['segmentation'])\n\n        if gt_bboxes:\n            gt_bboxes = np.array(gt_bboxes, dtype=np.float32)\n            gt_labels = np.array(gt_labels, dtype=np.int64)\n        else:\n            gt_bboxes = np.zeros((0, 4), dtype=np.float32)\n            gt_labels = np.array([], dtype=np.int64)\n\n        if gt_bboxes_ignore:\n            gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)\n        else:\n            gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)\n        ann = dict(\n            bboxes=gt_bboxes,\n            labels=gt_labels,\n            bboxes_ignore=gt_bboxes_ignore,\n            masks=gt_masks_ann,\n            seg_map=img_info['segm_file'])\n\n        return ann\n\n    def _panoptic2json(self, results, outfile_prefix):\n        panoptic_json_results = []\n        mmcv.mkdir_or_exist(outfile_prefix)\n        for idx in range(len(self)):\n            img_id = self.img_ids[idx]\n            panoptic = results[idx]\n            png_string, 
segments_info = panoptic\n            data = dict()\n            # hack\n            # To match the corresponding ids for panoptic segmentation prediction\n            # for both cityscape vps and cityscapes\n            if self.vps is not None:\n                data['image_id'] = \"_\".join(self.data_infos[idx]['file_name'].split(\".\")[0].split(\"_\")[:5])\n            else:\n                data['image_id'] = self.data_infos[idx]['file_name'].split(\"/\")[-1].split(\".\")[0][:-12]\n\n            for segment_info in segments_info:\n                isthing = segment_info.pop('isthing')\n                cat_id = segment_info['category_id']\n                if isthing is True:\n                    segment_info['category_id'] = self.ins2thing_ids[cat_id]\n                else:\n                    segment_info['category_id'] = self.seg2stuff_ids[cat_id]\n\n            png_path = self.data_infos[idx]['file_name'].replace(\n                '.jpg', '.png')\n            # hack: to save all the images into one folder\n            png_path = png_path.split(\"/\")[-1]\n            png_save_path = osp.join(outfile_prefix, png_path)\n\n            data['file_name'] = png_path\n\n            with open(png_save_path, 'wb') as f:\n                f.write(png_string)\n            data['segments_info'] = segments_info\n            panoptic_json_results.append(data)\n        return panoptic_json_results\n\n    def results2json(self, results, outfile_prefix):\n        \"\"\"Dump the detection results to a COCO style json file.\n\n        There are 3 types of results: proposals, bbox predictions, mask\n        predictions, and they have different data types. This method will\n        automatically recognize the type, and dump them to json files.\n\n        Args:\n            results (list[list | tuple | ndarray]): Testing results of the\n                dataset.\n            outfile_prefix (str): The filename prefix of the json files. 
If the\n                prefix is \"somepath/xxx\", the json files will be named\n                \"somepath/xxx.bbox.json\", \"somepath/xxx.segm.json\",\n                \"somepath/xxx.proposal.json\".\n\n        Returns:\n            dict[str: str]: Possible keys are \"bbox\", \"segm\", \"proposal\", and \\\n                values are corresponding filenames.\n        \"\"\"\n        result_files = dict()\n        if isinstance(results[0], list):\n            json_results = self._det2json(results)\n            result_files['bbox'] = f'{outfile_prefix}.bbox.json'\n            result_files['proposal'] = f'{outfile_prefix}.bbox.json'\n            mmcv.dump(json_results, result_files['bbox'])\n        elif isinstance(results[0], tuple):\n            if len(results[0]) == 3:  # dump the panoptic\n                instance_segm_results = []\n                panoptic_results = []\n                for idx in range(len(self)):\n                    det, seg, panoptic = results[idx]\n                    instance_segm_results.append([det, seg])\n                    panoptic_results.append(panoptic)\n                panoptic_json = dict()\n                panoptic_json['annotations'] = self._panoptic2json(\n                    panoptic_results, outfile_prefix)\n                result_files['panoptic'] = f'{outfile_prefix}.panoptic.json'\n                mmcv.dump(panoptic_json, result_files['panoptic'])\n            else:\n                instance_segm_results = results\n            json_results = self._segm2json(instance_segm_results)\n            result_files['bbox'] = f'{outfile_prefix}.bbox.json'\n            result_files['proposal'] = f'{outfile_prefix}.bbox.json'\n            result_files['segm'] = f'{outfile_prefix}.segm.json'\n            mmcv.dump(json_results[0], result_files['bbox'])\n            mmcv.dump(json_results[1], result_files['segm'])\n        elif isinstance(results[0], np.ndarray):\n            json_results = self._proposal2json(results)\n            
result_files['proposal'] = f'{outfile_prefix}.proposal.json'\n            mmcv.dump(json_results, result_files['proposal'])\n        else:\n            raise TypeError('invalid type of results')\n        return result_files\n\n    def results2txt(self, results, outfile_prefix):\n        \"\"\"Dump the detection results to a txt file.\n\n        Args:\n            results (list[list | tuple]): Testing results of the\n                dataset.\n            outfile_prefix (str): The filename prefix of the txt files.\n                If the prefix is \"somepath/xxx\",\n                the txt files will be named \"somepath/xxx.txt\".\n\n        Returns:\n            list[str]: Result txt files which contain corresponding \\\n                instance segmentation images.\n        \"\"\"\n        try:\n            import cityscapesscripts.helpers.labels as CSLabels\n        except ImportError:\n            raise ImportError('Please run \"pip install cityscapesscripts\" to '\n                              'install cityscapesscripts first.')\n        result_files = []\n        os.makedirs(outfile_prefix, exist_ok=True)\n        prog_bar = mmcv.ProgressBar(len(self))\n        for idx in range(len(self)):\n            result = results[idx]\n            filename = self.data_infos[idx]['filename']\n            basename = osp.splitext(osp.basename(filename))[0]\n            pred_txt = osp.join(outfile_prefix, basename + '_pred.txt')\n\n            bbox_result, segm_result = result\n            bboxes = np.vstack(bbox_result)\n            # segm results\n            if isinstance(segm_result, tuple):\n                # Some detectors use different scores for bbox and mask,\n                # like Mask Scoring R-CNN. 
Score of segm will be used instead\n                # of bbox score.\n                segms = mmcv.concat_list(segm_result[0])\n                mask_score = segm_result[1]\n            else:\n                # use bbox score for mask score\n                segms = mmcv.concat_list(segm_result)\n                mask_score = [bbox[-1] for bbox in bboxes]\n            labels = [\n                np.full(bbox.shape[0], i, dtype=np.int32)\n                for i, bbox in enumerate(bbox_result)\n            ]\n            labels = np.concatenate(labels)\n\n            assert len(bboxes) == len(segms) == len(labels)\n            num_instances = len(bboxes)\n            prog_bar.update()\n            with open(pred_txt, 'w') as fout:\n                for i in range(num_instances):\n                    pred_class = labels[i]\n                    classes = self.CLASSES[pred_class]\n                    class_id = CSLabels.name2label[classes].id\n                    score = mask_score[i]\n                    mask = maskUtils.decode(segms[i]).astype(np.uint8)\n                    png_filename = osp.join(outfile_prefix,\n                                            basename + f'_{i}_{classes}.png')\n                    mmcv.imwrite(mask, png_filename)\n                    fout.write(f'{osp.basename(png_filename)} {class_id} '\n                               f'{score}\\n')\n            result_files.append(pred_txt)\n\n        return result_files\n\n    def format_results(self, results, jsonfile_prefix=\"./test\", **kwargs):\n        \"\"\"Format the results to json (standard format for COCO evaluation).\n\n        Args:\n            results (list[tuple | numpy.ndarray]): Testing results of the\n                dataset.\n            jsonfile_prefix (str | None): The prefix of json files. It includes\n                the file path and the prefix of filename, e.g., \"a/b/prefix\".\n                If not specified, a temp file will be created. 
Default: \"./test\".\n\n        Returns:\n            tuple: (result_files, tmp_dir), result_files is a dict containing \\\n                the json filepaths, tmp_dir is the temporary directory created \\\n                for saving json files when jsonfile_prefix is not specified.\n        \"\"\"\n        assert isinstance(results, list), 'results must be a list'\n        assert len(results) == len(self), (\n            'The length of results is not equal to the dataset len: {} != {}'.\n            format(len(results), len(self)))\n\n        if jsonfile_prefix is None:\n            tmp_dir = tempfile.TemporaryDirectory()\n            jsonfile_prefix = osp.join(tmp_dir.name, 'results')\n        else:\n            tmp_dir = None\n        result_files = self.results2json(results, jsonfile_prefix)\n        return result_files, tmp_dir\n\n    def evaluate(self,\n                 results,\n                 metric='bbox',\n                 logger=None,\n                 outfile_prefix=None,\n                 classwise=False,\n                 proposal_nums=(100, 300, 1000),\n                 iou_thrs=np.arange(0.5, 0.96, 0.05),\n                 metric_items=None):\n        \"\"\"Evaluation in Cityscapes/COCO protocol.\n\n        Args:\n            results (list[list | tuple]): Testing results of the dataset.\n            metric (str | list[str]): Metrics to be evaluated. Options are\n                'bbox', 'segm', 'cityscapes', 'panoptic'.\n            logger (logging.Logger | str | None): Logger used for printing\n                related information during evaluation. Default: None.\n            outfile_prefix (str | None): The prefix of output file. It includes\n                the file path and the prefix of filename, e.g., \"a/b/prefix\".\n                If results are evaluated with COCO protocol, it would be the\n                prefix of output json file. 
For example, if the metrics are 'bbox'\n                and 'segm', the json files would be \"a/b/prefix.bbox.json\" and\n                \"a/b/prefix.segm.json\".\n                If results are evaluated with cityscapes protocol, it would be\n                the prefix of output txt/png files. The output files would be\n                png images under folder \"a/b/prefix/xxx/\" and the file name of\n                images would be written into a txt file\n                \"a/b/prefix/xxx_pred.txt\", where \"xxx\" is the video name of\n                cityscapes. If not specified, a temp file will be created.\n                Default: None.\n            classwise (bool): Whether to evaluate the AP for each class.\n            proposal_nums (Sequence[int]): Proposal numbers used for evaluating\n                recalls, such as recall@100, recall@1000.\n                Default: (100, 300, 1000).\n            iou_thrs (Sequence[float]): IoU thresholds used for evaluating\n                recalls. If set to a list, the average recall of all IoUs will\n                also be computed. 
Default: np.arange(0.5, 0.96, 0.05).\n\n        Returns:\n            dict[str, float]: COCO style evaluation metric or cityscapes mAP \\\n                and AP@50.\n        \"\"\"\n        eval_results = OrderedDict()\n\n        metrics = metric.copy() if isinstance(metric, list) else [metric]\n        allowed_metrics = [\n            'bbox', 'segm', 'cityscapes', 'panoptic'\n        ]\n        for metric in metrics:\n            if metric not in allowed_metrics:\n                raise KeyError(f'metric {metric} is not supported')\n\n        if 'cityscapes' in metrics:\n            eval_results.update(\n                self._evaluate_cityscapes(results, outfile_prefix, logger))\n            metrics.remove('cityscapes')\n\n        if iou_thrs is None:\n            iou_thrs = np.linspace(\n                .5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True)\n        if metric_items is not None:\n            if not isinstance(metric_items, list):\n                metric_items = [metric_items]\n\n        result_files, tmp_dir = self.format_results(results, outfile_prefix)\n\n        cocoGt = self.coco\n        for metric in metrics:\n            msg = f'Evaluating {metric}...'\n            if logger is None:\n                msg = '\\n' + msg\n            print_log(msg, logger=logger)\n\n            if metric == 'proposal_fast':\n                ar = self.fast_eval_recall(\n                    results, proposal_nums, iou_thrs, logger='silent')\n                log_msg = []\n                for i, num in enumerate(proposal_nums):\n                    eval_results[f'AR@{num}'] = ar[i]\n                    log_msg.append(f'\\nAR@{num}\\t{ar[i]:.4f}')\n                log_msg = ''.join(log_msg)\n                print_log(log_msg, logger=logger)\n                continue\n\n            if metric == 'panoptic':\n                from panopticapi.evaluation import pq_compute\n\n                with contextlib.redirect_stdout(io.StringIO()):\n                
    pq_res = pq_compute(\n                        self.ann_file['panoptic_ann'],\n                        result_files['panoptic'],\n                        gt_folder=self.seg_prefix,\n                        pred_folder=result_files['panoptic'].split('.')[0])\n                results = parse_pq_results(pq_res)\n                for k, v in results.items():\n                    eval_results[f'{metric}_{k}'] = f'{float(v):0.3f}'\n                print_log(\n                    'Panoptic Evaluation Results:\\n' +\n                    _print_panoptic_results(pq_res),\n                    logger=logger)\n                continue\n\n            iou_type = 'bbox' if metric == 'proposal' else metric\n            if metric not in result_files:\n                raise KeyError(f'{metric} is not in results')\n            try:\n                predictions = mmcv.load(result_files[metric])\n                if iou_type == 'segm':\n                    # Refer to https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/coco.py#L331  # noqa\n                    # When evaluating mask AP, if the results contain bbox,\n                    # cocoapi will use the box area instead of the mask area\n                    # for calculating the instance area. 
Though the overall AP\n                    # is not affected, this leads to different small, medium,\n                    # and large mask AP results.\n                    for x in predictions:\n                        x.pop('bbox')\n                cocoDt = cocoGt.loadRes(predictions)\n            except IndexError:\n                print_log(\n                    'The testing results of the whole dataset is empty.',\n                    logger=logger,\n                    level=logging.ERROR)\n                break\n\n            cocoEval = COCOeval(cocoGt, cocoDt, iou_type)\n            cocoEval.params.catIds = self.cat_ids\n            cocoEval.params.imgIds = self.img_ids\n            cocoEval.params.maxDets = list(proposal_nums)\n            cocoEval.params.iouThrs = iou_thrs\n            # mapping of cocoEval.stats\n            coco_metric_names = {\n                'mAP': 0,\n                'mAP_50': 1,\n                'mAP_75': 2,\n                'mAP_s': 3,\n                'mAP_m': 4,\n                'mAP_l': 5,\n                'AR@100': 6,\n                'AR@300': 7,\n                'AR@1000': 8,\n                'AR_s@1000': 9,\n                'AR_m@1000': 10,\n                'AR_l@1000': 11\n            }\n            if metric_items is not None:\n                for metric_item in metric_items:\n                    if metric_item not in coco_metric_names:\n                        raise KeyError(\n                            f'metric item {metric_item} is not supported')\n\n            if metric == 'proposal':\n                cocoEval.params.useCats = 0\n                cocoEval.evaluate()\n                cocoEval.accumulate()\n                cocoEval.summarize()\n                if metric_items is None:\n                    metric_items = [\n                        'AR@100', 'AR@300', 'AR@1000', 'AR_s@1000',\n                        'AR_m@1000', 'AR_l@1000'\n                    ]\n\n                for item in metric_items:\n             
       val = float(\n                        f'{cocoEval.stats[coco_metric_names[item]]:.3f}')\n                    eval_results[item] = val\n            else:\n                cocoEval.evaluate()\n                cocoEval.accumulate()\n                cocoEval.summarize()\n                if classwise:  # Compute per-category AP\n                    # Compute per-category AP\n                    # from https://github.com/facebookresearch/detectron2/\n                    precisions = cocoEval.eval['precision']\n                    # precision: (iou, recall, cls, area range, max dets)\n                    assert len(self.cat_ids) == precisions.shape[2]\n\n                    results_per_category = []\n                    for idx, catId in enumerate(self.cat_ids):\n                        # area range index 0: all area ranges\n                        # max dets index -1: typically 100 per image\n                        nm = self.coco.loadCats(catId)[0]\n                        precision = precisions[:, :, idx, 0, -1]\n                        precision = precision[precision > -1]\n                        if precision.size:\n                            ap = np.mean(precision)\n                        else:\n                            ap = float('nan')\n                        results_per_category.append(\n                            (f'{nm[\"name\"]}', f'{float(ap):0.3f}'))\n\n                    num_columns = min(6, len(results_per_category) * 2)\n                    results_flatten = list(\n                        itertools.chain(*results_per_category))\n                    headers = ['category', 'AP'] * (num_columns // 2)\n                    results_2d = itertools.zip_longest(*[\n                        results_flatten[i::num_columns]\n                        for i in range(num_columns)\n                    ])\n                    table_data = [headers]\n                    table_data += [result for result in results_2d]\n                    table = 
AsciiTable(table_data)\n                    print_log('\\n' + table.table, logger=logger)\n\n                if metric_items is None:\n                    metric_items = [\n                        'mAP', 'mAP_50', 'mAP_75', 'mAP_s', 'mAP_m', 'mAP_l'\n                    ]\n\n                for metric_item in metric_items:\n                    key = f'{metric}_{metric_item}'\n                    val = float(\n                        f'{cocoEval.stats[coco_metric_names[metric_item]]:.3f}'\n                    )\n                    eval_results[key] = val\n                ap = cocoEval.stats[:6]\n                eval_results[f'{metric}_mAP_copypaste'] = (\n                    f'{ap[0]:.3f} {ap[1]:.3f} {ap[2]:.3f} {ap[3]:.3f} '\n                    f'{ap[4]:.3f} {ap[5]:.3f}')\n\n        if tmp_dir is not None:\n            tmp_dir.cleanup()\n        return eval_results\n\n    def _evaluate_cityscapes(self, results, txtfile_prefix, logger):\n        \"\"\"Evaluation in Cityscapes protocol.\n\n        Args:\n            results (list): Testing results of the dataset.\n            txtfile_prefix (str | None): The prefix of output txt file\n            logger (logging.Logger | str | None): Logger used for printing\n                related information during evaluation. 
Default: None.\n\n        Returns:\n            dict[str: float]: Cityscapes evaluation results, containing 'mAP' \\\n                and 'AP@50'.\n        \"\"\"\n\n        try:\n            import cityscapesscripts.evaluation.evalInstanceLevelSemanticLabeling as CSEval  # noqa\n        except ImportError:\n            raise ImportError('Please run \"pip install cityscapesscripts\" to '\n                              'install cityscapesscripts first.')\n        msg = 'Evaluating in Cityscapes style'\n        if logger is None:\n            msg = '\\n' + msg\n        print_log(msg, logger=logger)\n\n        result_files, tmp_dir = self.format_results(results, txtfile_prefix)\n\n        if tmp_dir is None:\n            result_dir = osp.join(txtfile_prefix, 'results')\n        else:\n            result_dir = osp.join(tmp_dir.name, 'results')\n\n        eval_results = OrderedDict()\n        print_log(f'Evaluating results under {result_dir} ...', logger=logger)\n\n        # set global states in cityscapes evaluation API\n        CSEval.args.cityscapesPath = os.path.join(self.img_prefix, '../..')\n        CSEval.args.predictionPath = os.path.abspath(result_dir)\n        CSEval.args.predictionWalk = None\n        CSEval.args.JSONOutput = False\n        CSEval.args.colorized = False\n        CSEval.args.gtInstancesFile = os.path.join(result_dir,\n                                                   'gtInstances.json')\n        CSEval.args.groundTruthSearch = os.path.join(\n            self.img_prefix.replace('leftImg8bit', 'gtFine'),\n            '*/*_gtFine_instanceIds.png')\n\n        groundTruthImgList = glob.glob(CSEval.args.groundTruthSearch)\n        assert len(groundTruthImgList), 'Cannot find ground truth images' \\\n            f' in {CSEval.args.groundTruthSearch}.'\n        predictionImgList = []\n        for gt in groundTruthImgList:\n            predictionImgList.append(CSEval.getPrediction(gt, CSEval.args))\n        CSEval_results = 
CSEval.evaluateImgLists(predictionImgList,\n                                                 groundTruthImgList,\n                                                 CSEval.args)['averages']\n\n        eval_results['mAP'] = CSEval_results['allAp']\n        eval_results['AP@50'] = CSEval_results['allAp50%']\n        if tmp_dir is not None:\n            tmp_dir.cleanup()\n        return eval_results"
  },
  {
    "path": "external/cityscapes_step.py",
    "content": "import os\n\nimport numpy as np\n\nfrom mmdet.datasets.builder import DATASETS\nfrom mmdet.datasets.pipelines.compose import Compose\n\nfrom external.dataset.mIoU import eval_miou\n\n\n@DATASETS.register_module()\nclass CityscapesSTEP:\n    CLASSES = ('road', 'sidewalk', 'building', 'wall', 'fence', 'pole',\n               'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',\n               'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',\n               'bicycle')\n\n    def __init__(\n            self,\n            pipeline=None,\n            data_root=None,\n            test_mode=False,\n            split='train',\n    ):\n        # Let's figure out where is the cityscapes first\n        assert os.path.exists(os.path.join(data_root, 'license.txt')), \\\n            \"It seems that '{}' is not the root folder of cityscapes\".format(data_root)\n        assert os.path.exists(os.path.join(data_root, 'leftImg8bit')), \\\n            \"leftImg8bit cannot be found.\"\n        assert os.path.exists(os.path.join(data_root, 'gtFine')), \\\n            \"gtFine cannot be found.\"\n\n        if pipeline is None:\n            pipeline = []\n\n        image_main_dir = os.path.join(data_root, 'leftImg8bit', split)\n        gt_dir = os.path.join(data_root, 'gtFine', split)\n\n        locations = os.listdir(image_main_dir)\n        samples = []\n        for loc in locations:\n            for sample in os.listdir(os.path.join(image_main_dir, loc)):\n                location, seq_id, img_id, _ = sample.split('_')\n                assert location == loc\n                samples.append((location, int(seq_id), int(img_id)))\n        samples = sorted(samples)\n        self.samples = samples\n\n        # Set the image dirs\n        self.gt_dir = gt_dir\n        self.img_dir = image_main_dir\n\n        self.pipeline = Compose(pipeline)\n        self.load_ann_pipeline = Compose([\n            dict(\n                
type='LoadAnnotationsInstanceMasks',\n                with_mask=False,\n                with_seg=True,\n                with_inst=True,\n            ),\n        ])\n        self.test_mode = test_mode\n\n        self.flag = self._set_groups()\n\n        # eval\n        self.max_ins = 1000\n        self.no_obj_id = 255\n\n    def pre_pipeline(self, results):\n        results['img_prefix'] = None\n        results['img_fields'] = []\n        results['mask_fields'] = []\n        results['seg_fields'] = []\n        results['bbox_fields'] = []\n        return results\n\n    def prepare_test_img(self, idx):\n        get_idx = self.samples[idx]\n        filename = os.path.join(self.img_dir, get_idx[0], '{}_{:06d}_{:06d}_leftImg8bit.png'.format(*get_idx))\n        results = {\n            'img_info': {\n                'filename': filename\n            }\n        }\n        results = self.pre_pipeline(results)\n        return self.pipeline(results)\n\n    def prepare_val_annotation(self, idx):\n        get_idx = self.samples[idx]\n        results = {\n            'ann_info': {\n                'seg_map': os.path.join(self.gt_dir, get_idx[0],\n                                        '{}_{:06d}_{:06d}_gtFine_labelTrainIds.png'.format(*get_idx)),\n                'inst_map': os.path.join(self.gt_dir, get_idx[0],\n                                         '{}_{:06d}_{:06d}_gtFine_instanceTrainIds.png'.format(*get_idx)),\n            }\n        }\n        results = self.pre_pipeline(results)\n        return self.load_ann_pipeline(results)\n\n    def prepare_train_img(self, idx):\n        get_idx = self.samples[idx]\n        filename = os.path.join(self.img_dir, get_idx[0], '{}_{:06d}_{:06d}_leftImg8bit.png'.format(*get_idx))\n        results = {\n            'img_info': {\n                'filename': filename\n            },\n            'ann_info': {\n                'seg_map': os.path.join(self.gt_dir, get_idx[0],\n                                        
'{}_{:06d}_{:06d}_gtFine_labelTrainIds.png'.format(*get_idx)),\n                'inst_map': os.path.join(self.gt_dir, get_idx[0],\n                                         '{}_{:06d}_{:06d}_gtFine_instanceTrainIds.png'.format(*get_idx)),\n            }\n        }\n        results = self.pre_pipeline(results)\n        return self.pipeline(results)\n\n    # Copy and Modify from mmdet\n    def __getitem__(self, idx):\n        \"\"\"Get training/test data after pipeline.\n\n        Args:\n            idx (int): Index of data.\n\n        Returns:\n            dict: Training/test data (with annotation if `test_mode` is set \\\n                True).\n        \"\"\"\n\n        if self.test_mode:\n            return self.prepare_test_img(idx)\n        else:\n            while True:\n                cur_data = self.prepare_train_img(idx)\n                if cur_data is None:\n                    idx = self._rand_another(idx)\n                    continue\n                return cur_data\n\n    def _rand_another(self, idx):\n        \"\"\"Get another random index from the same group as the given index.\"\"\"\n        pool = np.arange(len(self))\n        return np.random.choice(pool)\n\n    def __len__(self):\n        return len(self.samples)\n\n    def _set_groups(self):\n        return np.zeros((len(self)), dtype=np.int64)\n\n    # The evaluate func\n    def evaluate(\n            self,\n            results,\n            **kwargs\n    ):\n        # logger and metric\n        thing_lower = 11\n        thing_upper = 19\n\n        num_thing_classes = 8\n        num_stuff_classes = 11\n        pred_results_handled = []\n        sem_preds = []\n\n        thing_knet2real = [11, 13]\n\n        for item in results:\n            bbox_results, mask_results, seg_results, _, _ = item\n            # in seg_info id starts from 1\n            inst_map, seg_info = seg_results\n            cat_map = np.zeros_like(inst_map) + num_thing_classes + num_stuff_classes\n            for instance in 
seg_info:\n                cat_cur = instance['category_id']\n                if instance['isthing']:\n                    cat_cur = thing_knet2real[cat_cur]\n                else:\n                    cat_cur -= 1\n                    offset = 0\n                    for thing_id in thing_knet2real:\n                        if cat_cur + offset >= thing_id:\n                            offset += 1\n                    cat_cur += offset\n                assert cat_cur < num_thing_classes + num_stuff_classes\n                cat_map[inst_map == instance['id']] = cat_cur\n                if not instance['isthing']:\n                    inst_map[inst_map == instance['id']] = 0\n            pred_results_handled.append(cat_map.astype(np.int32) * self.max_ins + inst_map.astype(np.int32))\n            sem_preds.append(cat_map)\n\n        gt_panseg = []\n        sem_targets = []\n        for idx in range(len(self)):\n            results = self.prepare_val_annotation(idx)\n            panseg_map = results['gt_instance_map']\n            sem_targets.append(panseg_map // self.max_ins)\n            gt_panseg.append(panseg_map)\n\n        vpq_results = []\n        for pred, gt in zip(pred_results_handled, gt_panseg):\n            vpq_result = vpq_eval([pred, gt])\n            vpq_results.append(vpq_result)\n\n        iou_per_class = np.stack([result[0] for result in vpq_results]).sum(axis=0)[\n                        :num_thing_classes + num_stuff_classes]\n        tp_per_class = np.stack([result[1] for result in vpq_results]).sum(axis=0)[\n                       :num_thing_classes + num_stuff_classes]\n        fn_per_class = np.stack([result[2] for result in vpq_results]).sum(axis=0)[\n                       :num_thing_classes + num_stuff_classes]\n        fp_per_class = np.stack([result[3] for result in vpq_results]).sum(axis=0)[\n                       :num_thing_classes + num_stuff_classes]\n\n        # calculate the PQs\n        epsilon = 0.\n        sq = iou_per_class / 
(tp_per_class + epsilon)\n        rq = tp_per_class / (tp_per_class + 0.5 *\n                             fn_per_class + 0.5 * fp_per_class + epsilon)\n        pq = sq * rq\n        # Only the cherry-picked classes ('person' = 11, 'car' = 13) are evaluated as things\n        things_index = np.zeros((num_thing_classes + num_stuff_classes,), dtype=bool)\n        things_index[thing_knet2real] = True\n        stuff_pq = pq[np.logical_not(things_index)]\n        things_pq = pq[things_index]\n\n        miou_per_class = eval_miou(sem_preds, sem_targets, num_classes=num_thing_classes + num_stuff_classes)\n\n        print(\"class        pq\\t\\tsq\\t\\trq\\t\\ttp\\t\\tfp\\t\\tfn\\t\\tmIoU\")\n\n        for i in range(len(self.CLASSES)):\n            print(\"{}{}{:.3f}\\t\\t{:.3f}\\t\\t{:.3f}\\t\\t{:.0f}\\t\\t{:.0f}\\t\\t{:.0f}\\t\\t{:.3f}\".format(\n                self.CLASSES[i], ' '*(13 - len(self.CLASSES[i])), pq[i], sq[i], rq[i], tp_per_class[i],\n                fp_per_class[i], fn_per_class[i], miou_per_class[i]\n            ))\n\n        return {\n            \"PQ\": np.nan_to_num(pq).mean() * 100,\n            \"Stuff PQ\": np.nan_to_num(stuff_pq).mean() * 100,\n            \"Things PQ\": np.nan_to_num(things_pq).mean() * 100,\n            \"mIoU\": np.nan_to_num(miou_per_class).mean() * 100,\n        }\n\n\ndef vpq_eval(element):\n    # Per-frame PQ statistics; ids are encoded as category_id * max_ins + instance_id.\n    import six\n    pred_ids, gt_ids = element\n    max_ins = 1000\n    ign_id = 255\n    offset = 256 * 256\n    num_cat = 19 + 1\n\n    iou_per_class = np.zeros(num_cat, dtype=np.float64)\n    tp_per_class = np.zeros(num_cat, dtype=np.float64)\n    fn_per_class = np.zeros(num_cat, dtype=np.float64)\n    fp_per_class = np.zeros(num_cat, dtype=np.float64)\n\n    def _ids_to_counts(id_array):\n        ids, counts = np.unique(id_array, return_counts=True)\n        return dict(six.moves.zip(ids, counts))\n\n    pred_areas = _ids_to_counts(pred_ids)\n    gt_areas = _ids_to_counts(gt_ids)\n\n    void_id = ign_id * max_ins\n    ign_ids = 
{\n        gt_id for gt_id in six.iterkeys(gt_areas)\n        if (gt_id // max_ins) == ign_id\n    }\n\n    int_ids = gt_ids.astype(np.int64) * offset + pred_ids.astype(np.int64)\n    int_areas = _ids_to_counts(int_ids)\n\n    def prediction_void_overlap(pred_id):\n        void_int_id = void_id * offset + pred_id\n        return int_areas.get(void_int_id, 0)\n\n    def prediction_ignored_overlap(pred_id):\n        total_ignored_overlap = 0\n        for _ign_id in ign_ids:\n            int_id = _ign_id * offset + pred_id\n            total_ignored_overlap += int_areas.get(int_id, 0)\n        return total_ignored_overlap\n\n    gt_matched = set()\n    pred_matched = set()\n\n    for int_id, int_area in six.iteritems(int_areas):\n        gt_id = int(int_id // offset)\n        gt_cat = int(gt_id // max_ins)\n        pred_id = int(int_id % offset)\n        pred_cat = int(pred_id // max_ins)\n        if gt_cat != pred_cat:\n            continue\n        union = (\n                gt_areas[gt_id] + pred_areas[pred_id] - int_area -\n                prediction_void_overlap(pred_id)\n        )\n        iou = int_area / union\n        if iou > 0.5:\n            tp_per_class[gt_cat] += 1\n            iou_per_class[gt_cat] += iou\n            gt_matched.add(gt_id)\n            pred_matched.add(pred_id)\n\n    for gt_id in six.iterkeys(gt_areas):\n        if gt_id in gt_matched:\n            continue\n        cat_id = gt_id // max_ins\n        if cat_id == ign_id:\n            continue\n        fn_per_class[cat_id] += 1\n\n    for pred_id in six.iterkeys(pred_areas):\n        if pred_id in pred_matched:\n            continue\n        if (prediction_ignored_overlap(pred_id) / pred_areas[pred_id]) > 0.5:\n            continue\n        cat = pred_id // max_ins\n        fp_per_class[cat] += 1\n\n    return iou_per_class, tp_per_class, fn_per_class, fp_per_class\n\n\nif __name__ == '__main__':\n    import dataset.pipelines.loading\n    import dataset.pipelines.transforms\n\n    
img_norm_cfg = dict(\n        mean=[123.675, 116.28, 103.53],\n        std=[58.395, 57.12, 57.375],\n        to_rgb=True\n    )\n    train_pipelines = [\n        dict(type='LoadImageFromFile'),\n        dict(type='LoadAnnotationsInstanceMasks', cherry=[11, 13]),\n        dict(type='KNetInsAdapterCherryPick', stuff_nums=11, cherry=[11, 13]),\n        dict(type='Resize', img_scale=(1024, 2048), ratio_range=[0.5, 2.0], keep_ratio=True),\n        dict(type='RandomFlip', flip_ratio=0.5),\n        dict(type='RandomCrop', crop_size=(1024, 2048)),\n        dict(type='Normalize', **img_norm_cfg),\n        dict(type='PadFutureMMDet', size_divisor=32, pad_val=dict(img=0, masks=0, seg=255)),\n        dict(type='DefaultFormatBundle'),\n        dict(type='Collect', keys=['img', 'gt_masks', 'gt_labels', 'gt_semantic_seg'],\n             meta_keys=('ori_shape', 'img_shape', 'pad_shape', 'scale_factor', 'flip',\n                        'flip_direction', 'img_norm_cfg')\n             ),\n    ]\n    data = CityscapesSTEP(\n        pipeline=train_pipelines,\n        data_root='data/cityscapes',\n        split='train',\n        test_mode=False\n    )\n    for item in data:\n        print(item)\n"
  },
  {
    "path": "external/cityscapes_vps.py",
    "content": "import contextlib\nimport io\nimport itertools\nimport os\nimport glob\nimport tempfile\nimport logging\nimport os.path as osp\nfrom collections import OrderedDict\n\nimport pycocotools.mask as maskUtils\n\nimport mmcv\nimport numpy as np\nfrom mmcv.utils import print_log\nfrom mmdet.datasets.builder import DATASETS\nfrom mmdet.datasets.coco import CocoDataset\nfrom mmdet.datasets.api_wrappers import COCO, COCOeval\nfrom terminaltables import AsciiTable\nfrom external.coco_panoptic import parse_pq_results, _print_panoptic_results\n\n\n@DATASETS.register_module()\nclass CityscapesVPSDataset(CocoDataset):\n    def __init__(self,\n                 ann_file,\n                 pipeline,\n                 data_root=None,\n                 img_prefix=None,\n                 seg_prefix=None,\n                 proposal_file=None,\n                 test_mode=False,\n                 offsets=None,\n                 ref_prefix=None,\n                 nframes_span_test=6):\n        super(CityscapesVPSDataset, self).__init__(\n            ann_file=ann_file,\n            pipeline=pipeline,\n            data_root=data_root,\n            img_prefix=img_prefix,\n            seg_prefix=seg_prefix,\n            proposal_file=proposal_file,\n            test_mode=test_mode)\n\n        # Hack: we use ref_img_infos to load reference images.\n        self.ref_img_infos = self.load_ref_annotations(\n                    self.ann_file)\n        self.ref_prefix = ref_prefix\n        self.offsets = offsets\n        self.nframes_span_test = nframes_span_test\n        self.iid2_img_infos = {x['id']: x for x in self.ref_img_infos}\n\n    CLASSES = ('person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',\n               'bicycle')\n\n    def load_ref_annotations(self, ann_file):\n        self.ref_coco = COCO(ann_file['ins_ann'])\n        self.ref_cat_ids = self.ref_coco.getCatIds()\n        self.ref_cat2label = {\n            cat_id: i + 1\n            for i, cat_id in 
enumerate(self.ref_cat_ids)\n        }\n        self.ref_img_ids = self.ref_coco.getImgIds()\n        img_infos = []\n        for i in self.ref_img_ids:\n            info = self.ref_coco.loadImgs([i])[0]\n            info['filename'] = info['file_name']\n            img_infos.append(info)\n        return img_infos\n\n    def load_annotations(self, ann_file):\n        \"\"\"Load annotation from COCO style annotation file.\n\n        Args:\n            ann_file (str): Path of annotation file.\n\n        Returns:\n            list[dict]: Annotation info from COCO api.\n        \"\"\"\n\n        self.coco = COCO(ann_file['ins_ann'])\n        self.cat_ids = self.coco.get_cat_ids(cat_names=self.CLASSES)\n        self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}\n        self.img_ids = sorted(self.coco.get_img_ids())\n\n        self.panoptic_anns = mmcv.load(ann_file['panoptic_ann'])\n\n        self.stuff_ids = [\n            k['trainid'] for k in self.panoptic_anns['categories']\n            if k['isthing'] == 0\n        ]\n\n        self.thing_ids = [\n            k['trainid'] for k in self.panoptic_anns['categories']\n            if k['isthing'] == 1\n        ]\n\n        assert self.thing_ids == self.cat_ids\n\n        self.seg2stuff_ids = {\n            i + 1: stuff_id\n            for i, stuff_id in enumerate(self.stuff_ids)\n        }\n\n        self.seg2stuff_ids.update({0: 0})\n\n        self.ins2thing_ids = {\n            i: thing_id\n            for i, thing_id in enumerate(self.thing_ids)\n        }\n\n        data_infos = []\n        total_ann_ids = []\n        for i in self.img_ids:\n            info = self.coco.load_imgs([i])[0]\n            info['filename'] = info['file_name']\n            data_infos.append(info)\n            ann_ids = self.coco.get_ann_ids(img_ids=[i])\n            total_ann_ids.extend(ann_ids)\n        assert len(set(total_ann_ids)) == len(\n            total_ann_ids), f\"Annotation ids in '{ann_file}' are not 
unique!\"\n        return data_infos\n\n    def _filter_imgs(self, min_size=32):\n        \"\"\"Filter images too small or without ground truths.\"\"\"\n        valid_inds = []\n        # obtain images that contain annotation\n        ids_with_ann = set(_['image_id'] for _ in self.coco.anns.values())\n        # obtain images that contain annotations of the required categories\n        ids_in_cat = set()\n        for i, class_id in enumerate(self.cat_ids):\n            ids_in_cat |= set(self.coco.cat_img_map[class_id])\n        # merge the image id sets of the two conditions and use the merged set\n        # to filter out images if self.filter_empty_gt=True\n        ids_in_cat &= ids_with_ann\n\n        valid_img_ids = []\n        for i, img_info in enumerate(self.data_infos):\n            img_id = img_info['id']\n            ann_ids = self.coco.getAnnIds(imgIds=[img_id])\n            ann_info = self.coco.loadAnns(ann_ids)\n            all_iscrowd = all([_['iscrowd'] for _ in ann_info])\n            if self.filter_empty_gt and (self.img_ids[i] not in ids_in_cat\n                                         or all_iscrowd):\n                continue\n            if min(img_info['width'], img_info['height']) >= min_size:\n                valid_inds.append(i)\n                valid_img_ids.append(img_id)\n        self.img_ids = valid_img_ids\n        return valid_inds\n\n    def prepare_train_img(self, idx):\n        \"\"\"Get training data and annotations after pipeline.\n\n        Args:\n            idx (int): Index of data.\n\n        Returns:\n            dict: Training data and annotation after pipeline with new keys \\\n                introduced by pipeline.\n        \"\"\"\n\n        img_info = self.data_infos[idx]\n        ann_info = self.get_ann_info(idx)\n        results = [dict(img_info=img_info, ann_info=ann_info)]\n\n        iid = img_info['id']\n        # self.offsets = [-1, 1] for Cityscapes_VPS\n        offsets = self.offsets.copy()\n        # random 
sampling of future or past 5-th frame [-1, 1]\n        while True:\n            m = np.random.choice(offsets)\n            ref_iid = iid + m\n            if ref_iid in self.img_ids and self.check_whether_has_correspondence(ref_iid, iid):\n                break\n            offsets.remove(m)\n            # If all offset values fail, return None.\n            if len(offsets) == 0:\n                return None\n        # Reference image: information and annotations\n        ref_img_info = self.iid2_img_infos[ref_iid]\n        ref_ann_info = self.get_ref_ann_info_by_iid(ref_iid, ref_img_info)\n        results.append(dict(img_info=ref_img_info, ann_info=ref_ann_info))\n\n        if self.proposals is not None:\n            # results is a list of per-frame dicts; attach proposals to the key frame\n            results[0]['proposals'] = self.proposals[idx]\n\n        self.pre_pipeline(results)\n\n        return self.pipeline(results)\n\n    def check_whether_has_correspondence(self, ref_iid, iid):\n        ref_img_info = self.iid2_img_infos[ref_iid]\n        ref_ann_info = self.get_ref_ann_info_by_iid(ref_iid, ref_img_info)\n\n        img_info = self.iid2_img_infos[iid]\n        ann_info = self.get_ref_ann_info_by_iid(iid, img_info)\n        nomatch = self.check_match(ref_ann_info, ann_info)\n        if nomatch:  # no match\n            return False\n        else:\n            return True\n\n    def check_match(self, ref_ann_info, ann_info):\n        ref_ids = ref_ann_info['instance_ids'].tolist()\n        gt_ids = ann_info['instance_ids'].tolist()\n        gt_pids = [ref_ids.index(i) if i in ref_ids else -1 for i in gt_ids]\n        nomatch = (np.array(gt_pids) == -1).all()\n        return nomatch\n\n    def prepare_test_img(self, idx):\n        \"\"\"Get testing data after pipeline.\n\n        Args:\n            idx (int): Index of data.\n\n        Returns:\n            dict: Testing data after pipeline with new keys introduced by \\\n                pipeline.\n        \"\"\"\n\n        img_info = self.data_infos[idx]\n        prev_img_info = self.data_infos[idx - 1] if idx % (self.nframes_span_test) > 0 else img_info\n        img_info['ref_id'] = prev_img_info['id'] - 1\n        img_info['ref_filename'] = prev_img_info['file_name']\n        results = dict(img_info=img_info)\n        if self.proposals is not None:\n            results['proposals'] = self.proposals[idx]\n        self.pre_test_pipeline(results)\n        return self.pipeline(results)\n\n    def pre_pipeline(self, results):\n        \"\"\"Prepare each frame's results dict (key and reference) for the pipeline.\"\"\"\n        for result in results:\n            result['img_prefix'] = self.img_prefix\n            result['seg_prefix'] = self.seg_prefix\n            result['proposal_file'] = self.proposal_file\n            result['bbox_fields'] = []\n            result['mask_fields'] = []\n            result['seg_fields'] = []\n            seg_filename = result['ann_info']['seg_map'].replace('leftImg8bit', 'gtFine_color').\\\n                replace('newImg8bit', 'final_mask')\n\n            result['ann_info']['seg_map'] = seg_filename\n\n    def pre_test_pipeline(self, results):\n        results['img_prefix'] = self.img_prefix\n        results['seg_prefix'] = self.seg_prefix\n        results['ref_prefix'] = self.ref_prefix\n        results['proposal_file'] = self.proposal_file\n        results['bbox_fields'] = []\n        results['mask_fields'] = []\n        results['ref_bbox_fields'] = []\n        results['ref_mask_fields'] = []\n\n    def _parse_ann_info(self, img_info, ann_info):\n        \"\"\"Parse bbox and mask annotation.\n\n        Args:\n            img_info (dict): Image info of an image.\n            ann_info (list[dict]): Annotation info of an image.\n\n        Returns:\n            dict: A dict containing the following keys: bboxes, \\\n                bboxes_ignore, labels, masks, instance_ids, seg_map. 
\\\n                \"masks\" are already decoded into binary masks.\n        \"\"\"\n        gt_bboxes = []\n        gt_labels = []\n        gt_bboxes_ignore = []\n        gt_masks_ann = []\n        gt_obj_ids = []\n\n        for i, ann in enumerate(ann_info):\n            if ann.get('ignore', False):\n                continue\n            x1, y1, w, h = ann['bbox']\n            if ann['area'] <= 0 or w < 1 or h < 1:\n                continue\n            if ann['category_id'] not in self.cat_ids:\n                continue\n            bbox = [x1, y1, x1 + w, y1 + h]\n            if ann.get('iscrowd', False):\n                gt_bboxes_ignore.append(bbox)\n            else:\n                gt_bboxes.append(bbox)\n                gt_labels.append(self.cat2label[ann['category_id']])\n                gt_masks_ann.append(ann['segmentation'])\n                gt_obj_ids.append(ann['inst_id'])\n\n        if gt_bboxes:\n            gt_bboxes = np.array(gt_bboxes, dtype=np.float32)\n            gt_labels = np.array(gt_labels, dtype=np.int64)\n            gt_obj_ids = np.array(gt_obj_ids, dtype=np.int64)\n        else:\n            gt_bboxes = np.zeros((0, 4), dtype=np.float32)\n            gt_labels = np.array([], dtype=np.int64)\n            gt_obj_ids = np.array([], dtype=np.int64)\n\n        if gt_bboxes_ignore:\n            gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)\n        else:\n            gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)\n\n        seg_map = img_info['filename'].replace('jpg', 'png')\n\n        ann = dict(\n            bboxes=gt_bboxes,\n            labels=gt_labels,\n            bboxes_ignore=gt_bboxes_ignore,\n            masks=gt_masks_ann,\n            instance_ids=gt_obj_ids,\n            seg_map=seg_map)\n\n        return ann\n\n    def get_ref_ann_info_by_iid(self, img_id, ref_img_info):\n        ann_ids = self.ref_coco.getAnnIds(imgIds=[img_id])\n        ann_info = self.ref_coco.loadAnns(ann_ids)\n        return 
self._parse_ann_info(ref_img_info, ann_info)\n\n    def _panoptic2json(self, results, outfile_prefix):\n        panoptic_json_results = []\n        mmcv.mkdir_or_exist(outfile_prefix)\n        for idx in range(len(self)):\n            img_id = self.img_ids[idx]\n            panoptic = results[idx]\n            png_string, segments_info = panoptic\n            data = dict()\n            # hack\n            # To match the corresponding ids for panoptic segmentation prediction\n            data['image_id'] = self.data_infos[idx]['file_name'].split(\"/\")[-1].split(\".\")[0][:-12]\n\n            for segment_info in segments_info:\n                isthing = segment_info.pop('isthing')\n                cat_id = segment_info['category_id']\n                if isthing is True:\n                    segment_info['category_id'] = self.ins2thing_ids[cat_id]\n                else:\n                    segment_info['category_id'] = self.seg2stuff_ids[cat_id]\n\n            png_path = self.data_infos[idx]['file_name'].replace(\n                '.jpg', '.png')\n            # hack: to save all the images into one folder\n            png_path = png_path.split(\"/\")[-1]\n            png_save_path = osp.join(outfile_prefix, png_path)\n\n            data['file_name'] = png_path\n\n            with open(png_save_path, 'wb') as f:\n                f.write(png_string)\n            data['segments_info'] = segments_info\n            panoptic_json_results.append(data)\n        return panoptic_json_results\n\n    def results2json(self, results, outfile_prefix):\n        \"\"\"Dump the detection results to a COCO style json file.\n\n        There are 3 types of results: proposals, bbox predictions, mask\n        predictions, and they have different data types. 
This method will\n        automatically recognize the type, and dump them to json files.\n\n        Args:\n            results (list[list | tuple | ndarray]): Testing results of the\n                dataset.\n            outfile_prefix (str): The filename prefix of the json files. If the\n                prefix is \"somepath/xxx\", the json files will be named\n                \"somepath/xxx.bbox.json\", \"somepath/xxx.segm.json\",\n                \"somepath/xxx.proposal.json\".\n\n        Returns:\n            dict[str: str]: Possible keys are \"bbox\", \"segm\", \"proposal\", and \\\n                values are corresponding filenames.\n        \"\"\"\n        result_files = dict()\n        if isinstance(results[0], list):\n            json_results = self._det2json(results)\n            result_files['bbox'] = f'{outfile_prefix}.bbox.json'\n            result_files['proposal'] = f'{outfile_prefix}.bbox.json'\n            mmcv.dump(json_results, result_files['bbox'])\n        elif isinstance(results[0], tuple):\n            if len(results[0]) == 3:  # dump the panoptic\n                instance_segm_results = []\n                panoptic_results = []\n                for idx in range(len(self)):\n                    det, seg, panoptic = results[idx]\n                    instance_segm_results.append([det, seg])\n                    panoptic_results.append(panoptic)\n                panoptic_json = dict()\n                panoptic_json['annotations'] = self._panoptic2json(\n                    panoptic_results, outfile_prefix)\n                result_files['panoptic'] = f'{outfile_prefix}.panoptic.json'\n                mmcv.dump(panoptic_json, result_files['panoptic'])\n            else:\n                instance_segm_results = results\n            json_results = self._segm2json(instance_segm_results)\n            result_files['bbox'] = f'{outfile_prefix}.bbox.json'\n            result_files['proposal'] = f'{outfile_prefix}.bbox.json'\n            
result_files['segm'] = f'{outfile_prefix}.segm.json'\n            mmcv.dump(json_results[0], result_files['bbox'])\n            mmcv.dump(json_results[1], result_files['segm'])\n        elif isinstance(results[0], np.ndarray):\n            json_results = self._proposal2json(results)\n            result_files['proposal'] = f'{outfile_prefix}.proposal.json'\n            mmcv.dump(json_results, result_files['proposal'])\n        else:\n            raise TypeError('invalid type of results')\n        return result_files\n\n    def results2txt(self, results, outfile_prefix):\n        \"\"\"Dump the detection results to txt files.\n\n        Args:\n            results (list[list | tuple]): Testing results of the\n                dataset.\n            outfile_prefix (str): The directory of the output txt files.\n                If the prefix is \"somepath/xxx\", the txt files will be\n                named \"somepath/xxx/{basename}_pred.txt\".\n\n        Returns:\n            list[str]: Result txt files which contain corresponding \\\n                instance segmentation images.\n        \"\"\"\n        try:\n            import cityscapesscripts.helpers.labels as CSLabels\n        except ImportError:\n            raise ImportError('Please run \"pip install cityscapesscripts\" to '\n                              'install cityscapesscripts first.')\n        result_files = []\n        os.makedirs(outfile_prefix, exist_ok=True)\n        prog_bar = mmcv.ProgressBar(len(self))\n        for idx in range(len(self)):\n            result = results[idx]\n            filename = self.data_infos[idx]['filename']\n            basename = osp.splitext(osp.basename(filename))[0]\n            pred_txt = osp.join(outfile_prefix, basename + '_pred.txt')\n\n            bbox_result, segm_result = result\n            bboxes = np.vstack(bbox_result)\n            # segm results\n            if isinstance(segm_result, tuple):\n                # Some detectors use different scores for bbox and mask,\n              
  # like Mask Scoring R-CNN. Score of segm will be used instead\n                # of bbox score.\n                segms = mmcv.concat_list(segm_result[0])\n                mask_score = segm_result[1]\n            else:\n                # use bbox score for mask score\n                segms = mmcv.concat_list(segm_result)\n                mask_score = [bbox[-1] for bbox in bboxes]\n            labels = [\n                np.full(bbox.shape[0], i, dtype=np.int32)\n                for i, bbox in enumerate(bbox_result)\n            ]\n            labels = np.concatenate(labels)\n\n            assert len(bboxes) == len(segms) == len(labels)\n            num_instances = len(bboxes)\n            prog_bar.update()\n            with open(pred_txt, 'w') as fout:\n                for i in range(num_instances):\n                    pred_class = labels[i]\n                    classes = self.CLASSES[pred_class]\n                    class_id = CSLabels.name2label[classes].id\n                    score = mask_score[i]\n                    mask = maskUtils.decode(segms[i]).astype(np.uint8)\n                    png_filename = osp.join(outfile_prefix,\n                                            basename + f'_{i}_{classes}.png')\n                    mmcv.imwrite(mask, png_filename)\n                    fout.write(f'{osp.basename(png_filename)} {class_id} '\n                               f'{score}\\n')\n            result_files.append(pred_txt)\n\n        return result_files\n\n    def format_results(self, results, jsonfile_prefix=None, **kwargs):\n        \"\"\"Format the results to json (standard format for COCO evaluation).\n\n        Args:\n            results (list[tuple | numpy.ndarray]): Testing results of the\n                dataset.\n            jsonfile_prefix (str | None): The prefix of json files. It includes\n                the file path and the prefix of filename, e.g., \"a/b/prefix\".\n                If not specified, a temp file will be created. 
Default: None.\n\n        Returns:\n            tuple: (result_files, tmp_dir), result_files is a dict containing \\\n                the json filepaths, tmp_dir is the temporary directory \\\n                created for saving json files when jsonfile_prefix is not \\\n                specified.\n        \"\"\"\n        assert isinstance(results, list), 'results must be a list'\n        assert len(results) == len(self), (\n            'The length of results is not equal to the dataset len: {} != {}'.\n            format(len(results), len(self)))\n\n        if jsonfile_prefix is None:\n            tmp_dir = tempfile.TemporaryDirectory()\n            jsonfile_prefix = osp.join(tmp_dir.name, 'results')\n        else:\n            tmp_dir = None\n        result_files = self.results2json(results, jsonfile_prefix)\n        return result_files, tmp_dir\n\n    def evaluate(self,\n                 results,\n                 metric='bbox',\n                 logger=None,\n                 outfile_prefix=None,\n                 classwise=False,\n                 proposal_nums=(100, 300, 1000),\n                 iou_thrs=np.arange(0.5, 0.96, 0.05),\n                 metric_items=None):\n        \"\"\"Evaluation in Cityscapes/COCO protocol.\n\n        Args:\n            results (list[list | tuple]): Testing results of the dataset.\n            metric (str | list[str]): Metrics to be evaluated. Options are\n                'bbox', 'segm', 'cityscapes', 'panoptic'.\n            logger (logging.Logger | str | None): Logger used for printing\n                related information during evaluation. Default: None.\n            outfile_prefix (str | None): The prefix of output file. It includes\n                the file path and the prefix of filename, e.g., \"a/b/prefix\".\n                If results are evaluated with COCO protocol, it would be the\n                prefix of output json file.
For example, if the metrics are 'bbox'\n                and 'segm', the json files would be \"a/b/prefix.bbox.json\" and\n                \"a/b/prefix.segm.json\".\n                If results are evaluated with cityscapes protocol, it would be\n                the prefix of output txt/png files. The output files would be\n                png images under folder \"a/b/prefix/xxx/\" and the file name of\n                images would be written into a txt file\n                \"a/b/prefix/xxx_pred.txt\", where \"xxx\" is the video name of\n                cityscapes. If not specified, a temp file will be created.\n                Default: None.\n            classwise (bool): Whether to evaluate the AP for each class.\n            proposal_nums (Sequence[int]): Proposal numbers used for evaluating\n                recalls, such as recall@100, recall@1000.\n                Default: (100, 300, 1000).\n            iou_thrs (Sequence[float]): IoU threshold used for evaluating\n                recalls. If set to a list, the average recall of all IoUs will\n                also be computed.
Default: np.arange(0.5, 0.96, 0.05).\n\n        Returns:\n            dict[str, float]: COCO style evaluation metric or cityscapes mAP \\\n                and AP@50.\n        \"\"\"\n        eval_results = OrderedDict()\n\n        metrics = metric.copy() if isinstance(metric, list) else [metric]\n        allowed_metrics = [\n            'bbox', 'segm', 'cityscapes', 'panoptic'\n        ]\n        for metric in metrics:\n            if metric not in allowed_metrics:\n                raise KeyError(f'metric {metric} is not supported')\n\n        if 'cityscapes' in metrics:\n            eval_results.update(\n                self._evaluate_cityscapes(results, outfile_prefix, logger))\n            metrics.remove('cityscapes')\n\n        if iou_thrs is None:\n            iou_thrs = np.linspace(\n                .5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True)\n        if metric_items is not None:\n            if not isinstance(metric_items, list):\n                metric_items = [metric_items]\n\n        result_files, tmp_dir = self.format_results(results, outfile_prefix)\n\n        cocoGt = self.coco\n        for metric in metrics:\n            msg = f'Evaluating {metric}...'\n            if logger is None:\n                msg = '\\n' + msg\n            print_log(msg, logger=logger)\n\n            if metric == 'proposal_fast':\n                ar = self.fast_eval_recall(\n                    results, proposal_nums, iou_thrs, logger='silent')\n                log_msg = []\n                for i, num in enumerate(proposal_nums):\n                    eval_results[f'AR@{num}'] = ar[i]\n                    log_msg.append(f'\\nAR@{num}\\t{ar[i]:.4f}')\n                log_msg = ''.join(log_msg)\n                print_log(log_msg, logger=logger)\n                continue\n\n            if metric == 'panoptic':\n                from panopticapi.evaluation import pq_compute\n 
               with contextlib.redirect_stdout(io.StringIO()):\n                    pq_res = pq_compute(\n                        self.ann_file['panoptic_ann'],\n                        result_files['panoptic'],\n                        gt_folder=self.seg_prefix,\n                        pred_folder=result_files['panoptic'].rsplit('.', 2)[0])\n                pq_results = parse_pq_results(pq_res)\n                for k, v in pq_results.items():\n                    eval_results[f'{metric}_{k}'] = f'{float(v):0.3f}'\n                print_log(\n                    'Panoptic Evaluation Results:\\n' +\n                    _print_panoptic_results(pq_res),\n                    logger=logger)\n                continue\n\n            iou_type = 'bbox' if metric == 'proposal' else metric\n            if metric not in result_files:\n                raise KeyError(f'{metric} is not in results')\n            try:\n                predictions = mmcv.load(result_files[metric])\n                if iou_type == 'segm':\n                    # Refer to https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/coco.py#L331  # noqa\n                    # When evaluating mask AP, if the results contain bbox,\n                    # cocoapi will use the box area instead of the mask area\n                    # for calculating the instance area.
Though the overall AP\n                    # is not affected, this leads to different small, medium,\n                    # and large mask AP results.\n                    for x in predictions:\n                        x.pop('bbox')\n                cocoDt = cocoGt.loadRes(predictions)\n            except IndexError:\n                print_log(\n                    'The testing results of the whole dataset are empty.',\n                    logger=logger,\n                    level=logging.ERROR)\n                break\n\n            cocoEval = COCOeval(cocoGt, cocoDt, iou_type)\n            cocoEval.params.catIds = self.cat_ids\n            cocoEval.params.imgIds = self.img_ids\n            cocoEval.params.maxDets = list(proposal_nums)\n            cocoEval.params.iouThrs = iou_thrs\n            # mapping of cocoEval.stats\n            coco_metric_names = {\n                'mAP': 0,\n                'mAP_50': 1,\n                'mAP_75': 2,\n                'mAP_s': 3,\n                'mAP_m': 4,\n                'mAP_l': 5,\n                'AR@100': 6,\n                'AR@300': 7,\n                'AR@1000': 8,\n                'AR_s@1000': 9,\n                'AR_m@1000': 10,\n                'AR_l@1000': 11\n            }\n            if metric_items is not None:\n                for metric_item in metric_items:\n                    if metric_item not in coco_metric_names:\n                        raise KeyError(\n                            f'metric item {metric_item} is not supported')\n\n            if metric == 'proposal':\n                cocoEval.params.useCats = 0\n                cocoEval.evaluate()\n                cocoEval.accumulate()\n                cocoEval.summarize()\n                if metric_items is None:\n                    metric_items = [\n                        'AR@100', 'AR@300', 'AR@1000', 'AR_s@1000',\n                        'AR_m@1000', 'AR_l@1000'\n                    ]\n\n                for item in metric_items:\n             
       val = float(\n                        f'{cocoEval.stats[coco_metric_names[item]]:.3f}')\n                    eval_results[item] = val\n            else:\n                cocoEval.evaluate()\n                cocoEval.accumulate()\n                cocoEval.summarize()\n                if classwise:  # Compute per-category AP\n                    # Compute per-category AP\n                    # from https://github.com/facebookresearch/detectron2/\n                    precisions = cocoEval.eval['precision']\n                    # precision: (iou, recall, cls, area range, max dets)\n                    assert len(self.cat_ids) == precisions.shape[2]\n\n                    results_per_category = []\n                    for idx, catId in enumerate(self.cat_ids):\n                        # area range index 0: all area ranges\n                        # max dets index -1: typically 100 per image\n                        nm = self.coco.loadCats(catId)[0]\n                        precision = precisions[:, :, idx, 0, -1]\n                        precision = precision[precision > -1]\n                        if precision.size:\n                            ap = np.mean(precision)\n                        else:\n                            ap = float('nan')\n                        results_per_category.append(\n                            (f'{nm[\"name\"]}', f'{float(ap):0.3f}'))\n\n                    num_columns = min(6, len(results_per_category) * 2)\n                    results_flatten = list(\n                        itertools.chain(*results_per_category))\n                    headers = ['category', 'AP'] * (num_columns // 2)\n                    results_2d = itertools.zip_longest(*[\n                        results_flatten[i::num_columns]\n                        for i in range(num_columns)\n                    ])\n                    table_data = [headers]\n                    table_data += [result for result in results_2d]\n                    table = 
AsciiTable(table_data)\n                    print_log('\\n' + table.table, logger=logger)\n\n                if metric_items is None:\n                    metric_items = [\n                        'mAP', 'mAP_50', 'mAP_75', 'mAP_s', 'mAP_m', 'mAP_l'\n                    ]\n\n                for metric_item in metric_items:\n                    key = f'{metric}_{metric_item}'\n                    val = float(\n                        f'{cocoEval.stats[coco_metric_names[metric_item]]:.3f}'\n                    )\n                    eval_results[key] = val\n                ap = cocoEval.stats[:6]\n                eval_results[f'{metric}_mAP_copypaste'] = (\n                    f'{ap[0]:.3f} {ap[1]:.3f} {ap[2]:.3f} {ap[3]:.3f} '\n                    f'{ap[4]:.3f} {ap[5]:.3f}')\n\n        if tmp_dir is not None:\n            tmp_dir.cleanup()\n        return eval_results\n\n    def _evaluate_cityscapes(self, results, txtfile_prefix, logger):\n        \"\"\"Evaluation in Cityscapes protocol.\n\n        Args:\n            results (list): Testing results of the dataset.\n            txtfile_prefix (str | None): The prefix of output txt file\n            logger (logging.Logger | str | None): Logger used for printing\n                related information during evaluation. 
Default: None.\n\n        Returns:\n            dict[str, float]: Cityscapes evaluation results, containing 'mAP' \\\n                and 'AP@50'.\n        \"\"\"\n\n        try:\n            import cityscapesscripts.evaluation.evalInstanceLevelSemanticLabeling as CSEval  # noqa\n        except ImportError:\n            raise ImportError('Please run \"pip install cityscapesscripts\" to '\n                              'install cityscapesscripts first.')\n        msg = 'Evaluating in Cityscapes style'\n        if logger is None:\n            msg = '\\n' + msg\n        print_log(msg, logger=logger)\n\n        result_files, tmp_dir = self.format_results(results, txtfile_prefix)\n\n        if tmp_dir is None:\n            result_dir = osp.join(txtfile_prefix, 'results')\n        else:\n            result_dir = osp.join(tmp_dir.name, 'results')\n\n        eval_results = OrderedDict()\n        print_log(f'Evaluating results under {result_dir} ...', logger=logger)\n\n        # set global states in cityscapes evaluation API\n        CSEval.args.cityscapesPath = os.path.join(self.img_prefix, '../..')\n        CSEval.args.predictionPath = os.path.abspath(result_dir)\n        CSEval.args.predictionWalk = None\n        CSEval.args.JSONOutput = False\n        CSEval.args.colorized = False\n        CSEval.args.gtInstancesFile = os.path.join(result_dir,\n                                                   'gtInstances.json')\n        CSEval.args.groundTruthSearch = os.path.join(\n            self.img_prefix.replace('leftImg8bit', 'gtFine'),\n            '*/*_gtFine_instanceIds.png')\n\n        groundTruthImgList = glob.glob(CSEval.args.groundTruthSearch)\n        assert len(groundTruthImgList), 'Cannot find ground truth images' \\\n            f' in {CSEval.args.groundTruthSearch}.'\n        predictionImgList = []\n        for gt in groundTruthImgList:\n            predictionImgList.append(CSEval.getPrediction(gt, CSEval.args))\n        CSEval_results = 
CSEval.evaluateImgLists(predictionImgList,\n                                                 groundTruthImgList,\n                                                 CSEval.args)['averages']\n\n        eval_results['mAP'] = CSEval_results['allAp']\n        eval_results['AP@50'] = CSEval_results['allAp50%']\n        if tmp_dir is not None:\n            tmp_dir.cleanup()\n        return eval_results"
  },
  {
    "path": "external/coco_panoptic.py",
    "content": "import contextlib\nimport io\nimport itertools\nimport logging\nimport tempfile\nimport os.path as osp\nfrom collections import OrderedDict\n\nimport mmcv\nimport numpy as np\nfrom mmcv.utils import print_log\nfrom mmdet.datasets.builder import DATASETS\nfrom mmdet.datasets.coco import CocoDataset\nfrom mmdet.datasets.api_wrappers import COCO, COCOeval\nfrom terminaltables import AsciiTable\n\n\n@DATASETS.register_module()\nclass CocoPanopticDatasetCustom(CocoDataset):\n\n    def load_annotations(self, ann_file):\n        \"\"\"Load annotation from COCO style annotation file.\n\n        Args:\n            ann_file (str): Path of annotation file.\n\n        Returns:\n            list[dict]: Annotation info from COCO api.\n        \"\"\"\n\n        self.coco = COCO(ann_file['ins_ann'])\n        self.cat_ids = self.coco.get_cat_ids(cat_names=self.CLASSES)\n        self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}\n        self.img_ids = sorted(self.coco.get_img_ids())\n\n        self.panoptic_anns = mmcv.load(ann_file['panoptic_ann'])\n\n        self.stuff_ids = [\n            k['id'] for k in self.panoptic_anns['categories']\n            if k['isthing'] == 0\n        ]\n        self.thing_ids = [\n            k['id'] for k in self.panoptic_anns['categories']\n            if k['isthing'] == 1\n        ]\n\n        assert self.thing_ids == self.cat_ids\n\n        self.seg2stuff_ids = {\n            i + 1: stuff_id\n            for i, stuff_id in enumerate(self.stuff_ids)\n        }\n        self.seg2stuff_ids.update({0: 0})\n\n        self.ins2thing_ids = {\n            i: thing_id\n            for i, thing_id in enumerate(self.thing_ids)\n        }\n\n        data_infos = []\n        total_ann_ids = []\n        for i in self.img_ids:\n            info = self.coco.load_imgs([i])[0]\n            info['filename'] = info['file_name']\n            data_infos.append(info)\n            ann_ids = self.coco.get_ann_ids(img_ids=[i])\n         
   total_ann_ids.extend(ann_ids)\n        assert len(set(total_ann_ids)) == len(\n            total_ann_ids), f\"Annotation ids in '{ann_file}' are not unique!\"\n        return data_infos\n\n    def get_ann_info(self, idx):\n        \"\"\"Get COCO annotation by index.\n\n        Args:\n            idx (int): Index of data.\n\n        Returns:\n            dict: Annotation info of specified index.\n        \"\"\"\n\n        img_id = self.data_infos[idx]['id']\n        ann_ids = self.coco.get_ann_ids(img_ids=[img_id])\n        ann_info = self.coco.load_anns(ann_ids)\n        return self._parse_ann_info(self.data_infos[idx], ann_info)\n\n    def get_cat_ids(self, idx):\n        \"\"\"Get COCO category ids by index.\n\n        Args:\n            idx (int): Index of data.\n\n        Returns:\n            list[int]: All categories in the image of specified index.\n        \"\"\"\n\n        img_id = self.data_infos[idx]['id']\n        ann_ids = self.coco.get_ann_ids(img_ids=[img_id])\n        ann_info = self.coco.load_anns(ann_ids)\n        return [ann['category_id'] for ann in ann_info]\n\n    def _parse_ann_info(self, img_info, ann_info):\n        \"\"\"Parse bbox and mask annotation.\n\n        Args:\n            img_info (dict): Image info of the annotated image.\n            ann_info (list[dict]): Annotation info of an image.\n\n        Returns:\n            dict: A dict containing the following keys: bboxes, bboxes_ignore,\\\n                labels, masks, seg_map.
\"masks\" are raw annotations and not \\\n                decoded into binary masks.\n        \"\"\"\n        gt_bboxes = []\n        gt_labels = []\n        gt_bboxes_ignore = []\n        gt_masks_ann = []\n        for i, ann in enumerate(ann_info):\n            if ann.get('ignore', False):\n                continue\n            x1, y1, w, h = ann['bbox']\n            inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0))\n            inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0))\n            if inter_w * inter_h == 0:\n                continue\n            if ann['area'] <= 0 or w < 1 or h < 1:\n                continue\n            if ann['category_id'] not in self.cat_ids:\n                continue\n            bbox = [x1, y1, x1 + w, y1 + h]\n            if ann.get('iscrowd', False):\n                gt_bboxes_ignore.append(bbox)\n            else:\n                gt_bboxes.append(bbox)\n                gt_labels.append(self.cat2label[ann['category_id']])\n                gt_masks_ann.append(ann.get('segmentation', None))\n\n        if gt_bboxes:\n            gt_bboxes = np.array(gt_bboxes, dtype=np.float32)\n            gt_labels = np.array(gt_labels, dtype=np.int64)\n        else:\n            gt_bboxes = np.zeros((0, 4), dtype=np.float32)\n            gt_labels = np.array([], dtype=np.int64)\n\n        if gt_bboxes_ignore:\n            gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)\n        else:\n            gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)\n\n        seg_map = img_info['filename'].replace('jpg', 'png')\n\n        ann = dict(\n            bboxes=gt_bboxes,\n            labels=gt_labels,\n            bboxes_ignore=gt_bboxes_ignore,\n            masks=gt_masks_ann,\n            seg_map=seg_map)\n\n        return ann\n\n    def _panoptic2json(self, results, outfile_prefix):\n        panoptic_json_results = []\n        mmcv.mkdir_or_exist(outfile_prefix)\n        for idx in range(len(self)):\n      
      img_id = self.img_ids[idx]\n            panoptic = results[idx]\n            png_string, segments_info = panoptic\n            data = dict()\n            data['image_id'] = img_id\n            for segment_info in segments_info:\n                isthing = segment_info.pop('isthing')\n                cat_id = segment_info['category_id']\n                if isthing is True:\n                    segment_info['category_id'] = self.ins2thing_ids[cat_id]\n                else:\n                    segment_info['category_id'] = self.seg2stuff_ids[cat_id]\n\n            png_path = self.data_infos[idx]['file_name'].replace(\n                '.jpg', '.png')\n            png_save_path = osp.join(outfile_prefix, png_path)\n            data['file_name'] = png_path\n            with open(png_save_path, 'wb') as f:\n                f.write(png_string)\n            data['segments_info'] = segments_info\n            panoptic_json_results.append(data)\n        return panoptic_json_results\n\n    def results2json(self, results, outfile_prefix):\n        \"\"\"Dump the detection results to a COCO style json file.\n\n        There are 3 types of results: proposals, bbox predictions, mask\n        predictions, and they have different data types. This method will\n        automatically recognize the type, and dump them to json files.\n\n        Args:\n            results (list[list | tuple | ndarray]): Testing results of the\n                dataset.\n            outfile_prefix (str): The filename prefix of the json files. 
If the\n                prefix is \"somepath/xxx\", the json files will be named\n                \"somepath/xxx.bbox.json\", \"somepath/xxx.segm.json\",\n                \"somepath/xxx.proposal.json\".\n\n        Returns:\n            dict[str: str]: Possible keys are \"bbox\", \"segm\", \"proposal\", and \\\n                values are corresponding filenames.\n        \"\"\"\n        result_files = dict()\n        if isinstance(results[0], list):\n            json_results = self._det2json(results)\n            result_files['bbox'] = f'{outfile_prefix}.bbox.json'\n            result_files['proposal'] = f'{outfile_prefix}.bbox.json'\n            mmcv.dump(json_results, result_files['bbox'])\n        elif isinstance(results[0], tuple):\n            if len(results[0]) == 3:\n                instance_segm_results = []\n                panoptic_results = []\n                for idx in range(len(self)):\n                    det, seg, panoptic = results[idx]\n                    instance_segm_results.append([det, seg])\n                    panoptic_results.append(panoptic)\n                panoptic_json = dict()\n                panoptic_json['annotations'] = self._panoptic2json(\n                    panoptic_results, outfile_prefix)\n                result_files['panoptic'] = f'{outfile_prefix}.panoptic.json'\n                mmcv.dump(panoptic_json, result_files['panoptic'])\n            else:\n                instance_segm_results = results\n            json_results = self._segm2json(instance_segm_results)\n            result_files['bbox'] = f'{outfile_prefix}.bbox.json'\n            result_files['proposal'] = f'{outfile_prefix}.bbox.json'\n            result_files['segm'] = f'{outfile_prefix}.segm.json'\n            mmcv.dump(json_results[0], result_files['bbox'])\n            mmcv.dump(json_results[1], result_files['segm'])\n        elif isinstance(results[0], np.ndarray):\n            json_results = self._proposal2json(results)\n            result_files['proposal'] 
= f'{outfile_prefix}.proposal.json'\n            mmcv.dump(json_results, result_files['proposal'])\n        else:\n            raise TypeError('invalid type of results')\n        return result_files\n\n    def format_results(self, results, jsonfile_prefix=None, **kwargs):\n        \"\"\"Format the results to json (standard format for COCO evaluation).\n\n        Args:\n            results (list[tuple | numpy.ndarray]): Testing results of the\n                dataset.\n            jsonfile_prefix (str | None): The prefix of json files. It includes\n                the file path and the prefix of filename, e.g., \"a/b/prefix\".\n                If not specified, a temp file will be created. Default: None.\n\n        Returns:\n            tuple: (result_files, tmp_dir), result_files is a dict containing \\\n                the json filepaths, tmp_dir is the temporary directory \\\n                created for saving json files when jsonfile_prefix is not \\\n                specified.\n        \"\"\"\n        assert isinstance(results, list), 'results must be a list'\n        assert len(results) == len(self), (\n            'The length of results is not equal to the dataset len: {} != {}'.\n            format(len(results), len(self)))\n\n        if jsonfile_prefix is None:\n            tmp_dir = tempfile.TemporaryDirectory()\n            jsonfile_prefix = osp.join(tmp_dir.name, 'results')\n        else:\n            tmp_dir = None\n        result_files = self.results2json(results, jsonfile_prefix)\n        return result_files, tmp_dir\n\n    def evaluate(self,\n                 results,\n                 metric='bbox',\n                 logger=None,\n                 jsonfile_prefix=None,\n                 classwise=False,\n                 proposal_nums=(100, 300, 1000),\n                 iou_thrs=None,\n                 metric_items=None):\n        \"\"\"Evaluation in COCO protocol.\n\n        Args:\n            results (list[list | tuple]): Testing results of the dataset.\n            
metric (str | list[str]): Metrics to be evaluated. Options are\n                'bbox', 'segm', 'proposal', 'proposal_fast', 'panoptic'.\n            logger (logging.Logger | str | None): Logger used for printing\n                related information during evaluation. Default: None.\n            jsonfile_prefix (str | None): The prefix of json files. It includes\n                the file path and the prefix of filename, e.g., \"a/b/prefix\".\n                If not specified, a temp file will be created. Default: None.\n            classwise (bool): Whether to evaluate the AP for each class.\n            proposal_nums (Sequence[int]): Proposal numbers used for evaluating\n                recalls, such as recall@100, recall@1000.\n                Default: (100, 300, 1000).\n            iou_thrs (Sequence[float], optional): IoU threshold used for\n                evaluating recalls/mAPs. If set to a list, the average of all\n                IoUs will also be computed. If not specified, [0.50, 0.55,\n                0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90, 0.95] will be used.\n                Default: None.\n            metric_items (list[str] | str, optional): Metric items that will\n                be returned.
If not specified, ``['AR@100', 'AR@300',\n                'AR@1000', 'AR_s@1000', 'AR_m@1000', 'AR_l@1000' ]`` will be\n                used when ``metric=='proposal'``, ``['mAP', 'mAP_50', 'mAP_75',\n                'mAP_s', 'mAP_m', 'mAP_l']`` will be used when\n                ``metric=='bbox' or metric=='segm'``.\n\n        Returns:\n            dict[str, float]: COCO style evaluation metric.\n        \"\"\"\n\n        metrics = metric if isinstance(metric, list) else [metric]\n        allowed_metrics = [\n            'bbox', 'segm', 'proposal', 'proposal_fast', 'panoptic'\n        ]\n        for metric in metrics:\n            if metric not in allowed_metrics:\n                raise KeyError(f'metric {metric} is not supported')\n        if iou_thrs is None:\n            iou_thrs = np.linspace(\n                .5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True)\n        if metric_items is not None:\n            if not isinstance(metric_items, list):\n                metric_items = [metric_items]\n\n        result_files, tmp_dir = self.format_results(results, jsonfile_prefix)\n\n        eval_results = OrderedDict()\n        cocoGt = self.coco\n        for metric in metrics:\n            msg = f'Evaluating {metric}...'\n            if logger is None:\n                msg = '\\n' + msg\n            print_log(msg, logger=logger)\n\n            if metric == 'proposal_fast':\n                ar = self.fast_eval_recall(\n                    results, proposal_nums, iou_thrs, logger='silent')\n                log_msg = []\n                for i, num in enumerate(proposal_nums):\n                    eval_results[f'AR@{num}'] = ar[i]\n                    log_msg.append(f'\\nAR@{num}\\t{ar[i]:.4f}')\n                log_msg = ''.join(log_msg)\n                print_log(log_msg, logger=logger)\n                continue\n\n            if metric == 'panoptic':\n                from panopticapi.evaluation import pq_compute\n                with 
contextlib.redirect_stdout(io.StringIO()):\n                    pq_res = pq_compute(\n                        self.ann_file['panoptic_ann'],\n                        result_files['panoptic'],\n                        gt_folder=self.seg_prefix,\n                        pred_folder=result_files['panoptic'].split('.')[0])\n                results = parse_pq_results(pq_res)\n                for k, v in results.items():\n                    eval_results[f'{metric}_{k}'] = f'{float(v):0.3f}'\n                print_log(\n                    'Panoptic Evaluation Results:\\n' +\n                    _print_panoptic_results(pq_res),\n                    logger=logger)\n                continue\n\n            iou_type = 'bbox' if metric == 'proposal' else metric\n            if metric not in result_files:\n                raise KeyError(f'{metric} is not in results')\n            try:\n                predictions = mmcv.load(result_files[metric])\n                if iou_type == 'segm':\n                    # Refer to https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/coco.py#L331  # noqa\n                    # When evaluating mask AP, if the results contain bbox,\n                    # cocoapi will use the box area instead of the mask area\n                    # for calculating the instance area. 
Though the overall AP\n                    # is not affected, this leads to different small, medium,\n                    # and large mask AP results.\n                    for x in predictions:\n                        x.pop('bbox')\n                cocoDt = cocoGt.loadRes(predictions)\n            except IndexError:\n                print_log(\n                    'The testing results of the whole dataset is empty.',\n                    logger=logger,\n                    level=logging.ERROR)\n                break\n\n            cocoEval = COCOeval(cocoGt, cocoDt, iou_type)\n            cocoEval.params.catIds = self.cat_ids\n            cocoEval.params.imgIds = self.img_ids\n            cocoEval.params.maxDets = list(proposal_nums)\n            cocoEval.params.iouThrs = iou_thrs\n            # mapping of cocoEval.stats\n            coco_metric_names = {\n                'mAP': 0,\n                'mAP_50': 1,\n                'mAP_75': 2,\n                'mAP_s': 3,\n                'mAP_m': 4,\n                'mAP_l': 5,\n                'AR@100': 6,\n                'AR@300': 7,\n                'AR@1000': 8,\n                'AR_s@1000': 9,\n                'AR_m@1000': 10,\n                'AR_l@1000': 11\n            }\n            if metric_items is not None:\n                for metric_item in metric_items:\n                    if metric_item not in coco_metric_names:\n                        raise KeyError(\n                            f'metric item {metric_item} is not supported')\n\n            if metric == 'proposal':\n                cocoEval.params.useCats = 0\n                cocoEval.evaluate()\n                cocoEval.accumulate()\n                cocoEval.summarize()\n                if metric_items is None:\n                    metric_items = [\n                        'AR@100', 'AR@300', 'AR@1000', 'AR_s@1000',\n                        'AR_m@1000', 'AR_l@1000'\n                    ]\n\n                for item in metric_items:\n             
       val = float(\n                        f'{cocoEval.stats[coco_metric_names[item]]:.3f}')\n                    eval_results[item] = val\n            else:\n                cocoEval.evaluate()\n                cocoEval.accumulate()\n                cocoEval.summarize()\n                if classwise:  # Compute per-category AP\n                    # Compute per-category AP\n                    # from https://github.com/facebookresearch/detectron2/\n                    precisions = cocoEval.eval['precision']\n                    # precision: (iou, recall, cls, area range, max dets)\n                    assert len(self.cat_ids) == precisions.shape[2]\n\n                    results_per_category = []\n                    for idx, catId in enumerate(self.cat_ids):\n                        # area range index 0: all area ranges\n                        # max dets index -1: typically 100 per image\n                        nm = self.coco.loadCats(catId)[0]\n                        precision = precisions[:, :, idx, 0, -1]\n                        precision = precision[precision > -1]\n                        if precision.size:\n                            ap = np.mean(precision)\n                        else:\n                            ap = float('nan')\n                        results_per_category.append(\n                            (f'{nm[\"name\"]}', f'{float(ap):0.3f}'))\n\n                    num_columns = min(6, len(results_per_category) * 2)\n                    results_flatten = list(\n                        itertools.chain(*results_per_category))\n                    headers = ['category', 'AP'] * (num_columns // 2)\n                    results_2d = itertools.zip_longest(*[\n                        results_flatten[i::num_columns]\n                        for i in range(num_columns)\n                    ])\n                    table_data = [headers]\n                    table_data += [result for result in results_2d]\n                    table = 
AsciiTable(table_data)\n                    print_log('\\n' + table.table, logger=logger)\n\n                if metric_items is None:\n                    metric_items = [\n                        'mAP', 'mAP_50', 'mAP_75', 'mAP_s', 'mAP_m', 'mAP_l'\n                    ]\n\n                for metric_item in metric_items:\n                    key = f'{metric}_{metric_item}'\n                    val = float(\n                        f'{cocoEval.stats[coco_metric_names[metric_item]]:.3f}'\n                    )\n                    eval_results[key] = val\n                ap = cocoEval.stats[:6]\n                eval_results[f'{metric}_mAP_copypaste'] = (\n                    f'{ap[0]:.3f} {ap[1]:.3f} {ap[2]:.3f} {ap[3]:.3f} '\n                    f'{ap[4]:.3f} {ap[5]:.3f}')\n\n        if tmp_dir is not None:\n            tmp_dir.cleanup()\n        return eval_results\n\n\ndef parse_pq_results(pq_res):\n    res = dict()\n    res['PQ'] = 100 * pq_res['All']['pq']\n    res['SQ'] = 100 * pq_res['All']['sq']\n    res['RQ'] = 100 * pq_res['All']['rq']\n    res['PQ_th'] = 100 * pq_res['Things']['pq']\n    res['SQ_th'] = 100 * pq_res['Things']['sq']\n    res['RQ_th'] = 100 * pq_res['Things']['rq']\n    res['PQ_st'] = 100 * pq_res['Stuff']['pq']\n    res['SQ_st'] = 100 * pq_res['Stuff']['sq']\n    res['RQ_st'] = 100 * pq_res['Stuff']['rq']\n    return res\n\n\ndef _print_panoptic_results(pq_res):\n    headers = ['', 'PQ', 'SQ', 'RQ', 'categories']\n    data = [headers]\n    for name in ['All', 'Things', 'Stuff']:\n        numbers = [\n            f'{(pq_res[name][k] * 100):0.3f}' for k in ['pq', 'sq', 'rq']\n        ]\n        row = [name] + numbers + [pq_res[name]['n']]\n        data.append(row)\n    table = AsciiTable(data)\n    return table.table\n"
  },
  {
    "path": "external/dataset/dvps_pipelines/__init__.py",
    "content": ""
  },
  {
    "path": "external/dataset/dvps_pipelines/loading.py",
    "content": "import mmcv\nimport numpy as np\nfrom mmdet.core import BitmapMasks\nfrom mmdet.datasets.builder import PIPELINES\n\n\ndef bitmasks2bboxes(bitmasks):\n    \"\"\"Compute tight (x1, y1, x2, y2) boxes from bitmap masks.\n\n    Masks without any foreground pixel keep an all-zero box.\n    \"\"\"\n    bitmasks_array = bitmasks.masks\n    boxes = np.zeros((bitmasks_array.shape[0], 4), dtype=np.float32)\n    x_any = np.any(bitmasks_array, axis=1)\n    y_any = np.any(bitmasks_array, axis=2)\n    for idx in range(bitmasks_array.shape[0]):\n        x = np.where(x_any[idx, :])[0]\n        y = np.where(y_any[idx, :])[0]\n        if len(x) > 0 and len(y) > 0:\n            boxes[idx, :] = np.array((x[0], y[0], x[-1], y[-1]), dtype=np.float32)\n    return boxes\n\n\n@PIPELINES.register_module()\nclass LoadImgDirect:\n    \"\"\"Load an image directly from the path stored in ``results['img']``.\n    \"\"\"\n\n    def __init__(self,\n                 to_float32=False,\n                 color_type='color'):\n        self.to_float32 = to_float32\n        self.color_type = color_type\n\n    def __call__(self, results):\n        \"\"\"Call functions to load image and get image meta information.\n\n        Args:\n            results (dict): Result dict requires \"img\" which is the img path.\n\n        Returns:\n            dict: The dict contains loaded image and meta information.\n            'img' : img\n            'img_shape' : img_shape\n            'ori_shape' : original shape\n            'img_fields' : the img fields\n        \"\"\"\n        img = mmcv.imread(results['img'], channel_order='rgb', flag=self.color_type)\n        if self.to_float32:\n            img = img.astype(np.float32)\n\n        results['img'] = img\n        results['img_shape'] = img.shape\n        results['ori_shape'] = img.shape\n        results['img_fields'] = ['img']\n        return results\n\n    def __repr__(self):\n        repr_str = (f'{self.__class__.__name__}('\n                    f'to_float32={self.to_float32}, '\n                    f\"color_type='{self.color_type}')\")\n        return repr_str\n\n\n@PIPELINES.register_module()\nclass 
LoadMultiImagesDirect(LoadImgDirect):\n    \"\"\"Load multiple images directly from file.\n    Please refer to :class:`LoadImgDirect` for detailed docstring.\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n\n    def __call__(self, results):\n        \"\"\"Call function.\n        For each dict in `results`, call the call function of\n        `LoadImgDirect` to load the image.\n        Args:\n            results (list[dict]): List of dicts from\n                :obj:`mmtrack.CocoVideoDataset`.\n        Returns:\n            list[dict]: List of dicts that contain the loaded images.\n        \"\"\"\n        outs = []\n        for _results in results:\n            _results = super().__call__(_results)\n            outs.append(_results)\n        return outs\n\n\n@PIPELINES.register_module()\nclass LoadAnnotationsDirect:\n    \"\"\"Load panoptic annotations (and optionally depth) directly from file.\n    \"\"\"\n\n    def __init__(self,\n                 with_depth=True,\n                 divisor: int = 1000,\n                 cherry_pick=False,\n                 cherry=None,\n                 viper=False,\n                 vipseg=False\n                 ):\n        self.with_depth = with_depth\n        self.panseg_divisor = divisor\n        self.cherry_pick = cherry_pick\n        self.cherry = cherry\n        self.viper = viper\n        self.vipseg = vipseg\n        if self.vipseg:\n            self.panseg_divisor = 1000\n\n    def __call__(self, results):\n        \"\"\"Call function to load depth and panoptic annotations.\n\n        Args:\n            results (dict): Result dict that requires the annotation path(s)\n                (\"ann\", or \"ann_class\" and \"ann_inst\") and, when\n                ``with_depth`` is True, the \"depth\" path.\n\n        Returns:\n            dict | None: The dict with loaded annotations, or None if no\n            valid instance remains.\n            'depth_fields' : the depth fields for supporting depth aug\n        \"\"\"\n\n        if self.with_depth:\n            depth = mmcv.imread(results['depth'], flag='unchanged').astype(np.float32) / 256.\n            
del results['depth']\n            depth[depth >= 80.] = 80.\n            results['gt_depth'] = depth\n            results['depth_fields'] = ['gt_depth']\n\n        local_divisor = 10000\n        if self.panseg_divisor == 0:\n            # Separate files store the class id and the instance id\n            gt_semantic_seg = mmcv.imread(results['ann_class'], flag='unchanged').astype(np.float32)\n            inst_map = mmcv.imread(results['ann_inst'], flag='unchanged').astype(np.float32)\n            ps_id = gt_semantic_seg * local_divisor + inst_map\n            del results['ann_class']\n            del results['ann_inst']\n        elif self.panseg_divisor == -1:\n            # KITTI-STEP mode: the panoptic label is stored as RGB, with\n            # R = class id and G * 256 + B = instance id\n            id_map = mmcv.imread(results['ann'], flag='color', channel_order='rgb')\n            gt_semantic_seg = id_map[..., 0].astype(np.float32)\n            inst_map = id_map[..., 1].astype(np.float32) * 256 + id_map[..., 2].astype(np.float32)\n            ps_id = gt_semantic_seg * local_divisor + inst_map\n            del results['ann']\n        else:\n            ps_id = mmcv.imread(results['ann'], flag='unchanged').astype(np.float32)\n            if self.vipseg:\n                ps_id = results['pre_hook'](ps_id)\n                del results['pre_hook']\n            # This is for VIPER and VIPSeg\n            if self.viper or self.vipseg:\n                ps_id[ps_id < 1000] *= 1000\n            del results['ann']\n            gt_semantic_seg = ps_id // self.panseg_divisor\n\n        if self.viper:\n            gt_semantic_seg[gt_semantic_seg >= results['thing_upper']] = results['no_obj_class']\n        results['gt_semantic_seg'] = gt_semantic_seg.astype(int)\n        results['seg_fields'] = ['gt_semantic_seg']\n\n        classes = []\n        masks = []\n        instance_ids = []\n        no_obj_class = results['no_obj_class']\n        for pan_seg_id in np.unique(ps_id):\n            classes.append(pan_seg_id // self.panseg_divisor if 
self.panseg_divisor > 0\n                           else pan_seg_id // local_divisor)\n            masks.append((ps_id == pan_seg_id).astype(int))\n            instance_ids.append(pan_seg_id)\n        gt_labels = np.stack(classes).astype(int)\n        gt_instance_ids = np.stack(instance_ids).astype(int)\n        gt_masks = BitmapMasks(masks, height=results['img_shape'][0], width=results['img_shape'][1])\n        # check the sanity of gt_masks\n        verify = np.sum(gt_masks.masks.astype(int), axis=0)\n        assert (verify == np.ones(gt_masks.masks.shape[-2:], dtype=verify.dtype)).all()\n        # now delete the no_obj_class\n        gt_masks.masks = np.delete(gt_masks.masks, gt_labels == no_obj_class, axis=0)\n        gt_instance_ids = np.delete(gt_instance_ids, gt_labels == no_obj_class)\n        gt_labels = np.delete(gt_labels, gt_labels == no_obj_class)\n        if results['is_instance_only'] and not self.cherry_pick:\n            gt_masks.masks = np.delete(\n                gt_masks.masks,\n                (gt_labels >= results['thing_upper']) | (gt_labels < results['thing_lower']),\n                axis=0\n            )\n            gt_instance_ids = np.delete(\n                gt_instance_ids,\n                (gt_labels >= results['thing_upper']) | (gt_labels < results['thing_lower'])\n            )\n            gt_labels = np.delete(\n                gt_labels,\n                (gt_labels >= results['thing_upper']) | (gt_labels < results['thing_lower'])\n            )\n            gt_labels -= results['thing_lower']\n        elif results['is_instance_only'] and self.cherry_pick:\n            gt_masks.masks = np.delete(\n                gt_masks.masks,\n                list(map(lambda x: x not in self.cherry, gt_labels)),\n                axis=0\n            )\n            gt_instance_ids = np.delete(\n                gt_instance_ids,\n                list(map(lambda x: x not in self.cherry, gt_labels)),\n            )\n            gt_labels 
= np.delete(\n                gt_labels,\n                list(map(lambda x: x not in self.cherry, gt_labels)),\n            )\n            gt_labels = np.array(list(map(lambda x: self.cherry.index(x), gt_labels))) if len(gt_labels) > 0 else []\n\n        if len(gt_labels) == 0:\n            return None\n\n        results['gt_labels'] = gt_labels\n        results['gt_masks'] = gt_masks\n        results['gt_instance_ids'] = gt_instance_ids\n        results['mask_fields'] = ['gt_masks']\n\n        # generate boxes\n        boxes = bitmasks2bboxes(gt_masks)\n        results['gt_bboxes'] = boxes\n        results['bbox_fields'] = ['gt_bboxes']\n        return results\n\n\n@PIPELINES.register_module()\nclass LoadMultiAnnotationsDirect(LoadAnnotationsDirect):\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n\n    def __call__(self, results):\n        outs = []\n        for _results in results:\n            _results = super().__call__(_results)\n            if _results is None:\n                return None\n            outs.append(_results)\n        return outs\n"
  },
  {
    "path": "external/dataset/dvps_pipelines/transforms.py",
    "content": "import mmcv\nimport numpy as np\nfrom mmdet.datasets.builder import PIPELINES\nfrom mmdet.datasets.pipelines import Resize, RandomFlip, Pad, Normalize\n\n\n@PIPELINES.register_module()\nclass ResizeWithDepth(Resize):\n    \"\"\"Resize that also resizes the depth maps listed in ``depth_fields``.\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        assert kwargs['keep_ratio']\n        super().__init__(*args, **kwargs)\n\n    def _resize_depth(self, results):\n        \"\"\"Resize depth with ``results['scale']``\"\"\"\n        # Although depth is not discrete, we use nearest to match the segmentation\n        for key in results.get('depth_fields', []):\n            if self.keep_ratio:\n                results[key] = mmcv.imrescale(\n                    results[key],\n                    results['scale'],\n                    interpolation='nearest',\n                    backend=self.backend)\n            else:\n                results[key] = mmcv.imresize(\n                    results[key],\n                    results['scale'],\n                    interpolation='nearest',\n                    backend=self.backend)\n            results[key] /= results['scale_factor'].mean()\n\n    def __call__(self, results):\n        super().__call__(results)\n        self._resize_depth(results)\n        return results\n\n\n@PIPELINES.register_module()\nclass SeqResizeWithDepth(ResizeWithDepth):\n    \"\"\"Resize images.\n    Please refer to `mmdet.datasets.pipelines.transforms.py:Resize` for\n    detailed docstring.\n    Args:\n        share_params (bool): If True, share the resize parameters for all\n            images. 
Defaults to True.\n    \"\"\"\n\n    def __init__(self, share_params=True, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.share_params = share_params\n\n    def __call__(self, results):\n        \"\"\"Call function.\n        For each dict in results, call the call function of `Resize` to resize\n        image and corresponding annotations.\n        Args:\n            results (list[dict]): List of dicts from\n                :obj:`mmtrack.CocoVideoDataset`.\n        Returns:\n            list[dict]: List of dicts that contain resized results,\n            'img_shape', 'pad_shape', 'scale_factor', 'keep_ratio' keys\n            are added into each result dict.\n        \"\"\"\n        outs, scale = [], None\n        for i, _results in enumerate(results):\n            if self.share_params and i > 0:\n                _results['scale'] = scale\n            _results = super().__call__(_results)\n            if self.share_params and i == 0:\n                scale = _results['scale']\n            outs.append(_results)\n        return outs\n\n\n@PIPELINES.register_module()\nclass RandomFlipWithDepth(RandomFlip):\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n\n    def __call__(self, results):\n        super().__call__(results)\n        if results['flip']:\n            for key in results.get('depth_fields', []):\n                results[key] = mmcv.imflip(\n                    results[key], direction=results['flip_direction'])\n        return results\n\n\n@PIPELINES.register_module()\nclass SeqFlipWithDepth(RandomFlipWithDepth):\n    \"\"\"Randomly flip images.\n    Please refer to `mmdet.datasets.pipelines.transforms.py:RandomFlip` for\n    detailed docstring.\n    Args:\n        share_params (bool): If True, share the flip parameters for all images.\n            Defaults to True.\n    \"\"\"\n\n    def __init__(self, share_params=True, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        
self.share_params = share_params\n\n    def __call__(self, results):\n        \"\"\"Call function.\n        For each dict in results, call `RandomFlip` to randomly flip image.\n        Args:\n            results (list[dict]): List of dicts from\n                :obj:`mmtrack.CocoVideoDataset`.\n        Returns:\n            list[dict]: List of dicts that contain flipped results, 'flip',\n            'flip_direction' keys are added into the dict.\n        \"\"\"\n        if self.share_params:\n            if isinstance(self.direction, list):\n                # None means non-flip\n                direction_list = self.direction + [None]\n            else:\n                # None means non-flip\n                direction_list = [self.direction, None]\n\n            if isinstance(self.flip_ratio, list):\n                non_flip_ratio = 1 - sum(self.flip_ratio)\n                flip_ratio_list = self.flip_ratio + [non_flip_ratio]\n            else:\n                non_flip_ratio = 1 - self.flip_ratio\n                # exclude non-flip\n                single_ratio = self.flip_ratio / (len(direction_list) - 1)\n                flip_ratio_list = [single_ratio] * (len(direction_list) -\n                                                    1) + [non_flip_ratio]\n\n            cur_dir = np.random.choice(direction_list, p=flip_ratio_list)\n            flip = cur_dir is not None\n            flip_direction = cur_dir\n\n            for _results in results:\n                _results['flip'] = flip\n                _results['flip_direction'] = flip_direction\n\n        outs = []\n        for _results in results:\n            _results = super().__call__(_results)\n            outs.append(_results)\n        return outs\n\n\n@PIPELINES.register_module()\nclass SeqRandomCropWithDepth(object):\n    \"\"\"Sequentially randomly crop the images, bboxes and masks.\n    The crop offsets are sampled uniformly inside the image for the absolute\n    `crop_size`, then the cropped results are 
generated.\n    Args:\n        crop_size (tuple | None): Absolute (height, width) of the crop in\n            pixels. If None, no cropping is applied.\n        allow_negative_crop (bool, optional): Whether to allow a crop that does\n            not contain any bbox area. Defaults to False.\n        share_params (bool, optional): Whether to share the cropping parameters\n            for the images. Defaults to False.\n        bbox_clip_border (bool, optional): Whether to clip the objects outside\n            the border of the image. Defaults to True.\n        check_id_match (bool, optional): If True and exactly two frames are\n            given, return None when no instance id of the first frame appears\n            in the second one. Defaults to True.\n    Note:\n        - If the image is smaller than the absolute crop size, return the\n            original image.\n        - The keys for bboxes, labels and masks must be aligned. That is,\n          `gt_bboxes` corresponds to `gt_labels` and `gt_masks`, and\n          `gt_bboxes_ignore` corresponds to `gt_labels_ignore` and\n          `gt_masks_ignore`.\n        - If the crop does not contain any gt-bbox region and\n          `allow_negative_crop` is set to False, skip this image.\n    \"\"\"\n\n    def __init__(self,\n                 crop_size,\n                 allow_negative_crop=False,\n                 share_params=False,\n                 bbox_clip_border=True,\n                 check_id_match=True,\n                 ):\n        assert crop_size is None or (crop_size[0] > 0 and crop_size[1] > 0)\n        self.crop_size = crop_size\n        self.allow_negative_crop = allow_negative_crop\n        self.share_params = share_params\n        self.bbox_clip_border = bbox_clip_border\n        self.check_id_match = check_id_match\n        # The key correspondence from bboxes to labels and masks.\n        self.bbox2label = {\n            'gt_bboxes': ['gt_labels', 'gt_instance_ids'],\n            'gt_bboxes_ignore': ['gt_labels_ignore', 'gt_instance_ids_ignore']\n        }\n        self.bbox2mask = {\n            'gt_bboxes': 'gt_masks',\n            'gt_bboxes_ignore': 'gt_masks_ignore'\n        }\n\n    def get_offsets(self, img):\n        \"\"\"Randomly generate the offsets 
for cropping.\"\"\"\n        margin_h = max(img.shape[0] - self.crop_size[0], 0)\n        margin_w = max(img.shape[1] - self.crop_size[1], 0)\n        offset_h = np.random.randint(0, margin_h + 1)\n        offset_w = np.random.randint(0, margin_w + 1)\n        return offset_h, offset_w\n\n    def random_crop(self, results, offsets=None):\n        \"\"\"Call function to randomly crop images, bounding boxes, masks,\n        semantic segmentation maps.\n        Args:\n            results (dict): Result dict from loading pipeline.\n            offsets (tuple, optional): Pre-defined offsets for cropping.\n                Default to None.\n        Returns:\n            dict: Randomly cropped results, 'img_shape' key in result dict is\n            updated according to crop size.\n        \"\"\"\n        # Only supporting img\n        assert results['img_fields'] == ['img']\n        img = results['img']\n        if offsets is not None:\n            offset_h, offset_w = offsets\n        else:\n            offset_h, offset_w = self.get_offsets(img)\n        results['crop_offsets'] = (offset_h, offset_w)\n        crop_y1, crop_y2 = offset_h, offset_h + self.crop_size[0]\n        crop_x1, crop_x2 = offset_w, offset_w + self.crop_size[1]\n\n        # crop the image\n        img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...]\n        img_shape = img.shape\n        results['img'] = img\n        results['img_shape'] = img_shape\n\n        # crop bboxes accordingly and clip to the image boundary\n        for key in results.get('bbox_fields', []):\n            # e.g. 
gt_bboxes and gt_bboxes_ignore\n            bbox_offset = np.array([offset_w, offset_h, offset_w, offset_h],\n                                   dtype=np.float32)\n            bboxes = results[key] - bbox_offset\n            if self.bbox_clip_border:\n                bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1])\n                bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0])\n            valid_inds = (bboxes[:, 2] > bboxes[:, 0]) & (\n                    bboxes[:, 3] > bboxes[:, 1])\n            # If the crop does not contain any gt-bbox area and\n            # self.allow_negative_crop is False, skip this image.\n            if (key == 'gt_bboxes' and not valid_inds.any()\n                    and not self.allow_negative_crop):\n                return None\n            results[key] = bboxes[valid_inds, :]\n            # label fields. e.g. gt_labels and gt_labels_ignore\n            label_keys = self.bbox2label.get(key)\n            for label_key in label_keys:\n                if label_key in results:\n                    results[label_key] = results[label_key][valid_inds]\n\n            # mask fields, e.g. 
gt_masks and gt_masks_ignore\n            mask_key = self.bbox2mask.get(key)\n            if mask_key in results:\n                results[mask_key] = results[mask_key][\n                    valid_inds.nonzero()[0]].crop(\n                    np.asarray([crop_x1, crop_y1, crop_x2, crop_y2]))\n\n        # crop semantic seg\n        for key in results.get('seg_fields', []):\n            results[key] = results[key][crop_y1:crop_y2, crop_x1:crop_x2]\n\n        # crop depth\n        for key in results.get('depth_fields', []):\n            results[key] = results[key][crop_y1:crop_y2, crop_x1:crop_x2]\n\n        return results\n\n    def __call__(self, results):\n        \"\"\"Call function to sequentially randomly crop images, bounding boxes,\n        masks, semantic segmentation maps.\n        Args:\n            results (dict): Result dict from loading pipeline.\n        Returns:\n            dict: Randomly cropped results, 'img_shape' key in result dict is\n            updated according to crop size.\n        \"\"\"\n        if self.share_params and self.crop_size is not None:\n            offsets = self.get_offsets(results[0]['img'])\n        else:\n            offsets = None\n\n        if self.crop_size is not None:\n            outs = []\n            for _results in results:\n                _results = self.random_crop(_results, offsets)\n                if _results is None:\n                    return None\n                outs.append(_results)\n        else:\n            outs = []\n            for _results in results:\n                outs.append(_results)\n\n        if len(outs) == 2 and self.check_id_match:\n            ref_result, result = outs[1], outs[0]\n            if self.check_match(ref_result, result):\n                return None\n\n        return outs\n\n    def check_match(self, ref_results, results):\n        ref_ids = ref_results['gt_instance_ids'].tolist()\n        gt_ids = results['gt_instance_ids'].tolist()\n        gt_pids = [ref_ids.index(i) if 
i in ref_ids else -1 for i in gt_ids]\n        nomatch = (np.array(gt_pids) == -1).all()\n        return nomatch\n\n\n@PIPELINES.register_module()\nclass PadWithDepth(Pad):\n\n    def _pad_depth(self, results):\n        \"\"\"Pad depth according to\n        ``results['pad_shape']``.\"\"\"\n        for key in results.get('depth_fields', []):\n            results[key] = mmcv.impad(\n                results[key], shape=results['pad_shape'][:2], pad_val=0)\n\n    # mmdet's Pad._pad_seg pads with a fixed value, which in general does not\n    # match no_obj_class; pad the semantic map with no_obj_class instead.\n    def _pad_seg(self, results):\n        \"\"\"Pad semantic segmentation map according to\n        ``results['pad_shape']``.\"\"\"\n        no_obj_class = results['no_obj_class']\n        for key in results.get('seg_fields', []):\n            results[key] = mmcv.impad(\n                results[key],\n                shape=results['pad_shape'][:2],\n                pad_val=no_obj_class)\n\n    def __call__(self, results):\n        \"\"\"Call function to pad images, masks, semantic segmentation maps.\n\n        Args:\n            results (dict): Result dict from loading pipeline.\n\n        Returns:\n            dict: Updated result dict.\n        \"\"\"\n        self._pad_img(results)\n        self._pad_masks(results)\n        self._pad_seg(results)\n        self._pad_depth(results)\n        return results\n\n\n@PIPELINES.register_module()\nclass SeqPadWithDepth(PadWithDepth):\n    \"\"\"Pad images.\n    Please refer to `mmdet.datasets.pipelines.transforms.py:Pad` for detailed\n    docstring.\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n\n    def __call__(self, results):\n        \"\"\"Call function.\n        For each dict in results, call the call function of `Pad` to pad image.\n        Args:\n            results (list[dict]): List of dicts from\n                :obj:`mmtrack.CocoVideoDataset`.\n        Returns:\n            list[dict]: List of dict that 
contains padding results,\n            'pad_shape', 'pad_fixed_size' and 'pad_size_divisor' keys are\n            added into the dict.\n        \"\"\"\n        outs = []\n        for _results in results:\n            _results = super().__call__(_results)\n            outs.append(_results)\n        return outs\n\n\n# Identical to mmtrack's SeqNormalize; depth maps are not normalized.\n@PIPELINES.register_module()\nclass SeqNormalizeWithDepth(Normalize):\n    \"\"\"Normalize images.\n    Please refer to `mmdet.datasets.pipelines.transforms.py:Normalize` for\n    detailed docstring.\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n\n    def __call__(self, results):\n        \"\"\"Call function.\n        For each dict in results, call the call function of `Normalize` to\n        normalize image.\n        Args:\n            results (list[dict]): List of dicts from\n                :obj:`mmtrack.CocoVideoDataset`.\n        Returns:\n            list[dict]: List of dicts that contain normalized results,\n            'img_norm_cfg' key is added into result dict.\n        \"\"\"\n        outs = []\n        for _results in results:\n            _results = super().__call__(_results)\n            outs.append(_results)\n        return outs\n"
  },
  {
    "path": "external/dataset/dvps_pipelines/tricks.py",
    "content": "import numpy as np\nfrom mmdet.datasets.builder import PIPELINES\nfrom mmdet.datasets.pipelines import AutoAugment\n\n\n@PIPELINES.register_module()\nclass SeqAutoAug(AutoAugment):\n    \"\"\"Apply AutoAugment to a sequence of frames.\n\n    One policy is sampled per call and applied to every frame in\n    ``results``, so all frames in the sequence share the same policy.\n    \"\"\"\n    def __init__(self, policies):\n        super().__init__(policies=policies)\n\n    def __call__(self, results):\n        # Sample a single policy for the whole sequence.\n        transform = np.random.choice(self.transforms)\n        outs = []\n        for _results in results:\n            out = transform(_results)\n            outs.append(out)\n        return outs\n"
  },
  {
    "path": "external/dataset/forecasting_pipelines/__init__.py",
    "content": ""
  },
  {
    "path": "external/dataset/forecasting_pipelines/loading.py",
    "content": "import mmcv\nimport numpy as np\nfrom mmdet.core import BitmapMasks\n\nfrom mmdet.datasets.builder import PIPELINES\n\n\ndef bitmasks2bboxes(bitmasks):\n    \"\"\"Compute tight (x1, y1, x2, y2) boxes, in inclusive pixel coordinates,\n    from a :obj:`BitmapMasks`. Empty masks get an all-zero box.\n    \"\"\"\n    bitmasks_array = bitmasks.masks\n    boxes = np.zeros((bitmasks_array.shape[0], 4), dtype=np.float32)\n    x_any = np.any(bitmasks_array, axis=1)\n    y_any = np.any(bitmasks_array, axis=2)\n    for idx in range(bitmasks_array.shape[0]):\n        x = np.where(x_any[idx, :])[0]\n        y = np.where(y_any[idx, :])[0]\n        if len(x) > 0 and len(y) > 0:\n            boxes[idx, :] = np.array((x[0], y[0], x[-1], y[-1]), dtype=np.float32)\n    return boxes\n\n\n@PIPELINES.register_module()\nclass LoadMultiImagesFromFile:\n    \"\"\"Load several images from files and concatenate them channel-wise.\n    Required key is \"img_info\", whose \"filename\" entry is a list of image\n    paths. Added or updated keys are \"img\" (the images concatenated along the\n    last axis), \"img_shape\" and \"ori_shape\" (both the shape of \"img\");\n    \"img\" is also appended to \"img_fields\".\n    Args:\n        to_float32 (bool): Whether to convert the loaded image to a float32\n            numpy array. 
If set to False, the loaded image is an uint8 array.\n            Defaults to False.\n        color_type (str): The flag argument for :func:`mmcv.imfrombytes`.\n            Defaults to 'color'.\n        file_client_args (dict): Arguments to instantiate a FileClient.\n            See :class:`mmcv.fileio.FileClient` for details.\n            Defaults to ``dict(backend='disk')``.\n    \"\"\"\n\n    def __init__(self,\n                 to_float32=False,\n                 color_type='color',\n                 file_client_args=dict(backend='disk')):\n        self.to_float32 = to_float32\n        self.color_type = color_type\n        self.file_client_args = file_client_args.copy()\n        self.file_client = None\n\n    def __call__(self, results):\n        \"\"\"Call functions to load image and get image meta information.\n        Args:\n            results (dict): Result dict from :obj:`mmdet.CustomDataset`.\n        Returns:\n            dict: The dict contains loaded image and meta information.\n        \"\"\"\n\n        if self.file_client is None:\n            self.file_client = mmcv.FileClient(**self.file_client_args)\n\n        filenames = results['img_info']['filename']\n        imgs = []\n        for filename in filenames:\n            img_bytes = self.file_client.get(filename)\n            img = mmcv.imfrombytes(img_bytes, flag=self.color_type)\n            if self.to_float32:\n                img = img.astype(np.float32)\n            imgs.append(img)\n        img = np.concatenate(imgs, axis=-1)\n\n        results['img'] = img\n        results['img_shape'] = img.shape\n        results['ori_shape'] = img.shape\n        results['img_fields'].append('img')\n        return results\n\n    def __repr__(self):\n        repr_str = (f'{self.__class__.__name__}('\n                    f'to_float32={self.to_float32}, '\n                    f\"color_type='{self.color_type}', \"\n                    f'file_client_args={self.file_client_args})')\n        return 
repr_str\n\n\n@PIPELINES.register_module()\nclass LoadAnnotationsInstanceMasks:\n    def __init__(self,\n                 with_mask=True,\n                 with_seg=True,\n                 with_inst=False,\n                 file_client_args=dict(backend='disk')):\n        self.with_mask = with_mask\n        self.with_seg = with_seg\n        self.with_inst = with_inst\n        self.file_client_args = file_client_args.copy()\n        self.file_client = None\n\n    def _load_masks(self, results):\n        \"\"\"Private function to load instance mask annotations.\n        Args:\n            results (dict): Result dict from :obj:`mmdet.CustomDataset`.\n        Returns:\n            dict | None: The dict contains the loaded :obj:`BitmapMasks`\n                (`gt_masks`), labels, and boxes derived from the masks.\n                ``None`` is returned if ``with_mask`` is set and the\n                instance map contains no instance (id >= 10000).\n        \"\"\"\n\n        img_bytes = self.file_client.get(results['ann_info']['inst_map'])\n        inst_mask = mmcv.imfrombytes(img_bytes, flag='unchanged').squeeze()\n        if self.with_inst:\n            results['gt_instance_map'] = inst_mask.copy().astype(int)\n            results['gt_instance_map'][inst_mask < 10000] *= 1000\n        if not self.with_mask:\n            return results\n        masks = []\n        labels = []\n        for inst_id in np.unique(inst_mask):\n            if inst_id >= 10000:\n                masks.append((inst_mask == inst_id).astype(int))\n                labels.append(inst_id // 1000)\n        if len(masks) == 0:\n            return None\n        gt_masks = BitmapMasks(masks, height=inst_mask.shape[0], width=inst_mask.shape[1])\n        results['gt_masks'] = gt_masks\n        results['mask_fields'].append('gt_masks')\n        results['gt_labels'] = np.array(labels)\n\n        boxes = bitmasks2bboxes(gt_masks)\n        results['gt_bboxes'] = boxes\n        results['bbox_fields'].append('gt_bboxes')\n        return results\n
\n    def _load_semantic_seg(self, results):\n        \"\"\"Private function to load semantic segmentation annotations.\n        Args:\n            results (dict): Result dict from :obj:`dataset`.\n        Returns:\n            dict: The dict contains loaded semantic segmentation annotations.\n        \"\"\"\n        img_bytes = self.file_client.get(results['ann_info']['seg_map'])\n        results['gt_semantic_seg'] = mmcv.imfrombytes(\n            img_bytes, flag='unchanged').squeeze()\n        results['seg_fields'].append('gt_semantic_seg')\n        return results\n\n    def __call__(self, results):\n        \"\"\"Call function to load multiple types of annotations.\n        Args:\n            results (dict): Result dict from :obj:`mmdet.CustomDataset`.\n        Returns:\n            dict | None: The dict contains loaded bounding box, label, mask and\n                semantic segmentation annotations, or ``None`` if masks are\n                requested and the frame has no instance.\n        \"\"\"\n        if self.file_client is None:\n            self.file_client = mmcv.FileClient(**self.file_client_args)\n        if self.with_mask or self.with_inst:\n            results = self._load_masks(results)\n            if results is None:\n                return None\n        if self.with_seg:\n            results = self._load_semantic_seg(results)\n        return results\n\n    def __repr__(self):\n        repr_str = self.__class__.__name__\n        repr_str += f'(with_mask={self.with_mask}, '\n        repr_str += f'with_seg={self.with_seg}, '\n        repr_str += f'with_inst={self.with_inst}, '\n        repr_str += f'file_client_args={self.file_client_args})'\n        return repr_str\n"
  },
  {
    "path": "external/dataset/forecasting_pipelines/transforms.py",
    "content": "import mmcv\nimport numpy as np\nimport warnings\nfrom mmdet.datasets import PIPELINES\n\n\n@PIPELINES.register_module()\nclass NormalizeMultiple:\n    \"\"\"Normalize the image.\n\n    Added key is \"img_norm_cfg\".\n\n    Args:\n        mean (sequence): Mean values of 3 channels.\n        std (sequence): Std values of 3 channels.\n        to_rgb (bool): Whether to convert the image from BGR to RGB,\n            default is true.\n    \"\"\"\n\n    def __init__(self, mean, std, to_rgb=True):\n        self.mean = np.array(mean, dtype=np.float32)\n        self.std = np.array(std, dtype=np.float32)\n        self.to_rgb = to_rgb\n\n    def __call__(self, results):\n        \"\"\"Call function to normalize images.\n\n        Args:\n            results (dict): Result dict from loading pipeline.\n\n        Returns:\n            dict: Normalized results, 'img_norm_cfg' key is added into\n                result dict.\n        \"\"\"\n        for key in results.get('img_fields', ['img']):\n            if results[key].shape[-1] > 3:\n                num_3 = results[key].shape[-1]\n                assert num_3 % 3 == 0\n                num_img = num_3 // 3\n                img = np.ones_like(results[key]).astype(np.float32)\n                for i in range(num_img):\n                    img[..., 3 * i:3 * i + 3] = mmcv.imnormalize(\n                        results[key][..., 3 * i:3 * i + 3], self.mean, self.std, self.to_rgb)\n                results[key] = img\n            else:\n                results[key] = mmcv.imnormalize(results[key], self.mean, self.std, self.to_rgb)\n        results['img_norm_cfg'] = dict(\n            mean=self.mean, std=self.std, to_rgb=self.to_rgb)\n        return results\n\n    def __repr__(self):\n        repr_str = self.__class__.__name__\n        repr_str += f'(mean={self.mean}, std={self.std}, to_rgb={self.to_rgb})'\n        return repr_str\n\n\n@PIPELINES.register_module()\nclass PadFutureMMDet:\n    \"\"\"Pad the image & masks 
& segmentation map.\n    There are two padding modes: (1) pad to a fixed size and (2) pad to the\n    minimum size that is divisible by some number.\n    Added keys are \"pad_shape\", \"pad_fixed_size\", \"pad_size_divisor\",\n    Args:\n        size (tuple, optional): Fixed padding size.\n        size_divisor (int, optional): The divisor of padded size.\n        pad_to_square (bool): Whether to pad the image into a square.\n            Currently only used for YOLOX. Default: False.\n        pad_val (dict, optional): A dict for padding value, the default\n            value is `dict(img=0, masks=0, seg=255)`.\n    \"\"\"\n\n    def __init__(self,\n                 size=None,\n                 size_divisor=None,\n                 pad_to_square=False,\n                 pad_val=dict(img=0, masks=0, seg=255)):\n        self.size = size\n        self.size_divisor = size_divisor\n        if isinstance(pad_val, float) or isinstance(pad_val, int):\n            warnings.warn(\n                'pad_val of float type is deprecated now, '\n                f'please use pad_val=dict(img={pad_val}, '\n                f'masks={pad_val}, seg=255) instead.', DeprecationWarning)\n            pad_val = dict(img=pad_val, masks=pad_val, seg=255)\n        assert isinstance(pad_val, dict)\n        self.pad_val = pad_val\n        self.pad_to_square = pad_to_square\n\n        if pad_to_square:\n            assert size is None and size_divisor is None, \\\n                'The size and size_divisor must be None ' \\\n                'when pad2square is True'\n        else:\n            assert size is not None or size_divisor is not None, \\\n                'only one of size and size_divisor should be valid'\n            assert size is None or size_divisor is None\n\n    def _pad_img(self, results):\n        \"\"\"Pad images according to ``self.size``.\"\"\"\n        pad_val = self.pad_val.get('img', 0)\n        for key in results.get('img_fields', ['img']):\n            if 
self.pad_to_square:\n                max_size = max(results[key].shape[:2])\n                self.size = (max_size, max_size)\n            if self.size is not None:\n                padded_img = mmcv.impad(\n                    results[key], shape=self.size, pad_val=pad_val)\n            elif self.size_divisor is not None:\n                padded_img = mmcv.impad_to_multiple(\n                    results[key], self.size_divisor, pad_val=pad_val)\n            results[key] = padded_img\n        results['pad_shape'] = padded_img.shape\n        results['pad_fixed_size'] = self.size\n        results['pad_size_divisor'] = self.size_divisor\n\n    def _pad_masks(self, results):\n        \"\"\"Pad masks according to ``results['pad_shape']``.\"\"\"\n        pad_shape = results['pad_shape'][:2]\n        pad_val = self.pad_val.get('masks', 0)\n        for key in results.get('mask_fields', []):\n            results[key] = results[key].pad(pad_shape, pad_val=pad_val)\n\n    def _pad_seg(self, results):\n        \"\"\"Pad semantic segmentation map according to\n        ``results['pad_shape']``.\"\"\"\n        pad_val = self.pad_val.get('seg', 255)\n        for key in results.get('seg_fields', []):\n            results[key] = mmcv.impad(\n                results[key], shape=results['pad_shape'][:2], pad_val=pad_val)\n\n    def __call__(self, results):\n        \"\"\"Call function to pad images, masks, semantic segmentation maps.\n        Args:\n            results (dict): Result dict from loading pipeline.\n        Returns:\n            dict: Updated result dict.\n        \"\"\"\n        self._pad_img(results)\n        self._pad_masks(results)\n        self._pad_seg(results)\n        return results\n\n    def __repr__(self):\n        repr_str = self.__class__.__name__\n        repr_str += f'(size={self.size}, '\n        repr_str += f'size_divisor={self.size_divisor}, '\n        repr_str += f'pad_to_square={self.pad_to_square}, '\n        repr_str += f'pad_val={self.pad_val})'\n   
     return repr_str\n\n\n@PIPELINES.register_module()\nclass KNetInsAdapter:\n    \"\"\"Adapter that is used to convert city-style instance class-ids\n    to coco-style instance-ids (11-starting to 0-starting)\n    \"\"\"\n\n    def __init__(self, stuff_nums=11):\n        self.stuff_nums = stuff_nums\n\n    def __call__(self, results):\n        \"\"\"Call function to modify gt_labels\n        Args:\n            results (dict): Result dict from loading pipeline.\n        Returns:\n            dict: Updated result dict.\n        \"\"\"\n        results['gt_labels'] -= self.stuff_nums\n        return results\n"
  },
  {
    "path": "external/dataset/mIoU.py",
    "content": "import numpy as np\n\n\ndef eval_miou(results, targets, num_classes, ignore_index=255):\n    \"\"\"Compute per-class IoU accumulated over all (prediction, target) pairs.\n\n    Pixels whose target equals ``ignore_index`` are skipped. Classes that\n    appear in neither the predictions nor the targets get NaN.\n    \"\"\"\n    total_area_intersect = np.zeros((num_classes,), dtype=np.float64)\n    total_area_union = np.zeros((num_classes,), dtype=np.float64)\n    total_area_pred = np.zeros((num_classes,), dtype=np.float64)\n    total_area_label = np.zeros((num_classes,), dtype=np.float64)\n\n    for result, target in zip(results, targets):\n        mask = (target != ignore_index)\n        pred = result[mask]\n        label = target[mask]\n\n        intersect = pred[pred == label]\n        area_intersect, _ = np.histogram(intersect.astype(float), bins=num_classes, range=(0, num_classes - 1))\n        area_pred, _ = np.histogram(pred.astype(float), bins=num_classes, range=(0, num_classes - 1))\n        area_label, _ = np.histogram(label.astype(float), bins=num_classes, range=(0, num_classes - 1))\n        area_union = area_pred + area_label - area_intersect\n\n        total_area_intersect += area_intersect\n        total_area_pred += area_pred\n        total_area_label += area_label\n        total_area_union += area_union\n\n    # Classes with zero union (absent everywhere) yield NaN without a warning.\n    with np.errstate(divide='ignore', invalid='ignore'):\n        iou_per_class = total_area_intersect / total_area_union\n    return iou_per_class\n\n\nif __name__ == '__main__':\n    results = [\n        np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])\n    ]\n    targets = [\n        np.array([[1, 2, 3], [1, 1, 2], [255, 255, 255]])\n    ]\n    print(eval_miou(results, targets, 19))\n"
  },
  {
    "path": "external/dataset/pipelines/__init__.py",
    "content": ""
  },
  {
    "path": "external/dataset/pipelines/formatting.py",
    "content": "import numpy as np\nimport torch\nfrom mmcv.parallel import DataContainer as DC\nfrom mmdet.datasets.builder import PIPELINES\nfrom mmdet.datasets.pipelines import to_tensor\n\n\n@PIPELINES.register_module()\nclass ConcatVideoReferences(object):\n    \"\"\"Concat video references.\n\n    If the input list contains at least two dicts, concat the input list of\n    dict to one dict from 2-nd dict of the input list.\n\n    Args:\n        results (list[dict]): List of dict that contain keys such as 'img',\n            'img_metas', 'gt_masks','proposals', 'gt_bboxes',\n            'gt_bboxes_ignore', 'gt_labels','gt_semantic_seg',\n            'gt_instance_ids'.\n\n    Returns:\n        list[dict]: The first dict of outputs is the same as the first\n        dict of `results`. The second dict of outputs concats the\n        dicts in `results[1:]`.\n    \"\"\"\n\n    def __call__(self, results):\n        assert (isinstance(results, list)), 'results must be list'\n        outs = results[:1]\n        for i, result in enumerate(results[1:], 1):\n            if 'img' in result:\n                img = result['img']\n                if len(img.shape) < 3:\n                    img = np.expand_dims(img, -1)\n                if i == 1:\n                    result['img'] = np.expand_dims(img, -1)\n                else:\n                    outs[1]['img'] = np.concatenate(\n                        (outs[1]['img'], np.expand_dims(img, -1)), axis=-1)\n            for key in ['img_metas', 'gt_masks']:\n                if key in result:\n                    if i == 1:\n                        result[key] = [result[key]]\n                    else:\n                        outs[1][key].append(result[key])\n            for key in [\n                'proposals', 'gt_bboxes', 'gt_bboxes_ignore', 'gt_labels',\n                'gt_instance_ids',\n            ]:\n                if key not in result:\n                    continue\n                value = result[key]\n            
    if value.ndim == 1:\n                    value = value[:, None]\n                N = value.shape[0]\n                value = np.concatenate((np.full(\n                    (N, 1), i - 1, dtype=np.float32), value),\n                    axis=1)\n                if i == 1:\n                    result[key] = value\n                else:\n                    outs[1][key] = np.concatenate((outs[1][key], value),\n                                                  axis=0)\n            if 'gt_semantic_seg' in result:\n                if i == 1:\n                    result['gt_semantic_seg'] = result['gt_semantic_seg'][...,\n                                                                          None,\n                                                                          None]\n                else:\n                    outs[1]['gt_semantic_seg'] = np.concatenate(\n                        (outs[1]['gt_semantic_seg'],\n                         result['gt_semantic_seg'][..., None, None]),\n                        axis=-1)\n\n            if 'gt_depth' in result:\n                if i == 1:\n                    result['gt_depth'] = result['gt_depth'][...,\n                                                            None,\n                                                            None]\n                else:\n                    outs[1]['gt_depth'] = np.concatenate(\n                        (outs[1]['gt_depth'],\n                         result['gt_depth'][..., None, None]),\n                        axis=-1)\n            if i == 1:\n                outs.append(result)\n        return outs\n\n\n@PIPELINES.register_module()\nclass ConcatVideos(object):\n    \"\"\"Concat video references.\n\n    If the input list contains at least two dicts, concat the input list of\n    dict to one dict from 2-nd dict of the input list.\n\n    Args:\n        results (list[dict]): List of dict that contain keys such as 'img',\n            'img_metas', 'gt_masks','proposals', 'gt_bboxes',\n 
           'gt_bboxes_ignore', 'gt_labels','gt_semantic_seg',\n            'gt_instance_ids'.\n\n    Returns:\n        list[dict]: The first dict of outputs is the same as the first\n        dict of `results`. The second dict of outputs concats the\n        dicts in `results[1:]`.\n    \"\"\"\n\n    def __call__(self, results):\n        assert (isinstance(results, list)), 'results must be list'\n        outs = results[:1]\n        # outs = []\n        for i, result in enumerate(results[0:], 1):\n            if 'img' in result:\n                img = result['img']\n                if len(img.shape) < 3:\n                    img = np.expand_dims(img, -1)\n                if i == 1:\n                    result['img'] = np.expand_dims(img, -1)\n                else:\n                    outs[1]['img'] = np.concatenate(\n                        (outs[1]['img'], np.expand_dims(img, -1)), axis=-1)\n            for key in ['img_metas', 'gt_masks']:\n                if key in result:\n                    if i == 1:\n                        result[key] = [result[key]]\n                    else:\n                        outs[1][key].append(result[key])\n            for key in [\n                'proposals', 'gt_bboxes', 'gt_bboxes_ignore', 'gt_labels',\n                'gt_instance_ids'\n            ]:\n                if key not in result:\n                    continue\n                value = result[key]\n                if value.ndim == 1:\n                    value = value[:, None]\n                N = value.shape[0]\n                value = np.concatenate((np.full(\n                    (N, 1), i - 1, dtype=np.float32), value),\n                    axis=1)\n                if i == 1:\n                    result[key] = value\n                else:\n                    outs[1][key] = np.concatenate((outs[1][key], value),\n                                                  axis=0)\n            if 'gt_semantic_seg' in result:\n                if i == 1:\n                    
result['gt_semantic_seg'] = result['gt_semantic_seg'][...,\n                                                                          None,\n                                                                          None]\n                else:\n                    outs[1]['gt_semantic_seg'] = np.concatenate(\n                        (outs[1]['gt_semantic_seg'],\n                         result['gt_semantic_seg'][..., None, None]),\n                        axis=-1)\n            if i == 1:\n                outs.append(result)\n        res = []\n        res.append(outs[1])\n        return res\n\n\n@PIPELINES.register_module()\nclass MultiImagesToTensor(object):\n    \"\"\"Multi images to tensor.\n\n    1. Transpose and convert image/multi-images to Tensor.\n    2. Add prefix to every key in the second dict of the inputs. Then, add\n    these keys and corresponding values into the outputs.\n\n    Args:\n        ref_prefix (str): The prefix of key added to the second dict of inputs.\n            Defaults to 'ref'.\n    \"\"\"\n\n    def __init__(self, ref_prefix='ref'):\n        self.ref_prefix = ref_prefix\n\n    def __call__(self, results):\n        \"\"\"Multi images to tensor.\n\n        1. Transpose and convert image/multi-images to Tensor.\n        2. Add prefix to every key in the second dict of the inputs. 
Then, add\n        these keys and corresponding values into the output dict.\n\n        Args:\n            results (list[dict]): List of two dicts.\n\n        Returns:\n            dict: Each key in the first dict of `results` remains unchanged.\n            Each key in the second dict of `results` adds `self.ref_prefix`\n            as prefix.\n        \"\"\"\n        outs = []\n        for _results in results:\n            _results = self.images_to_tensor(_results)\n            outs.append(_results)\n\n        data = {}\n        data.update(outs[0])\n        if len(outs) == 2:\n            for k, v in outs[1].items():\n                data[f'{self.ref_prefix}_{k}'] = v\n\n        return data\n\n    def images_to_tensor(self, results):\n        \"\"\"Transpose and convert images/multi-images to Tensor.\"\"\"\n        if 'img' in results:\n            img = results['img']\n            if len(img.shape) == 3:\n                # (H, W, 3) to (3, H, W)\n                img = np.ascontiguousarray(img.transpose(2, 0, 1))\n            else:\n                # (H, W, 3, N) to (N, 3, H, W)\n                img = np.ascontiguousarray(img.transpose(3, 2, 0, 1))\n            results['img'] = to_tensor(img)\n        if 'proposals' in results:\n            results['proposals'] = to_tensor(results['proposals'])\n        if 'img_metas' in results:\n            results['img_metas'] = DC(results['img_metas'], cpu_only=True)\n        return results\n\n\n@PIPELINES.register_module()\nclass SeqDefaultFormatBundle(object):\n    \"\"\"Sequence Default formatting bundle.\n\n    It simplifies the pipeline of formatting common fields, including \"img\",\n    \"img_metas\", \"proposals\", \"gt_bboxes\", \"gt_instance_ids\",\n    \"gt_match_indices\", \"gt_bboxes_ignore\", \"gt_labels\", \"gt_masks\" and\n    \"gt_semantic_seg\". 
These fields are formatted as follows.\n\n    - img: (1) transpose, (2) to tensor, (3) to DataContainer (stack=True)\n    - img_metas: (1) to DataContainer (cpu_only=True)\n    - proposals: (1) to tensor, (2) to DataContainer\n    - gt_bboxes: (1) to tensor, (2) to DataContainer\n    - gt_instance_ids: (1) to tensor, (2) to DataContainer\n    - gt_match_indices: (1) to tensor, (2) to DataContainer\n    - gt_bboxes_ignore: (1) to tensor, (2) to DataContainer\n    - gt_labels: (1) to tensor, (2) to DataContainer\n    - gt_masks: (1) to DataContainer (cpu_only=True)\n    - gt_semantic_seg: (1) unsqueeze dim-0 (2) to tensor, \\\n                       (3) to DataContainer (stack=True)\n\n    Args:\n        ref_prefix (str): The prefix of key added to the second dict of input\n            list. Defaults to 'ref'.\n    \"\"\"\n\n    def __init__(self, ref_prefix='ref'):\n        self.ref_prefix = ref_prefix\n\n    def __call__(self, results):\n        \"\"\"Sequence Default formatting bundle call function.\n\n        Args:\n            results (list[dict]): List of two dicts.\n\n        Returns:\n            dict: The result dict contains the data that is formatted with\n            default bundle. 
Each key in the second dict of the input list\n            adds `self.ref_prefix` as prefix.\n        \"\"\"\n        outs = []\n        for _results in results:\n            _results = self.default_format_bundle(_results)\n            outs.append(_results)\n\n        data = {}\n        if self.ref_prefix == 'ref':\n            # origin frames\n            data.update(outs[0])\n            # reference frames\n            if len(outs) == 1:\n                # for k in outs[0]:\n                #     data[f'{self.ref_prefix}_{k}'] = None\n                pass\n            else:\n                for k, v in outs[1].items():\n                    data[f'{self.ref_prefix}_{k}'] = v\n        elif self.ref_prefix is None:\n            # origin frames\n            data.update(outs[0])\n\n        return data\n\n    def default_format_bundle(self, results):\n        \"\"\"Transform and format common fields in results.\n\n        Args:\n            results (dict): Result dict contains the data to convert.\n\n        Returns:\n            dict: The result dict contains the data that is formatted with\n            default bundle.\n        \"\"\"\n        if 'img' in results:\n            img = results['img']\n            if len(img.shape) == 3:\n                img = np.ascontiguousarray(img.transpose(2, 0, 1))\n            else:\n                img = np.ascontiguousarray(img.transpose(3, 2, 0, 1))\n            results['img'] = DC(to_tensor(img), stack=True)\n        for key in [\n            'proposals', 'gt_bboxes', 'gt_bboxes_ignore', 'gt_labels',\n            'gt_instance_ids', 'gt_match_indices',\n        ]:\n            if key not in results:\n                continue\n            results[key] = DC(to_tensor(results[key]))\n        for key in ['img_metas', 'gt_masks']:\n            if key in results:\n                results[key] = DC(results[key], cpu_only=True)\n        if 'gt_semantic_seg' in results:\n            semantic_seg = results['gt_semantic_seg']\n            
if len(semantic_seg.shape) == 2:\n                semantic_seg = semantic_seg[None, ...]\n            else:\n                semantic_seg = np.ascontiguousarray(\n                    semantic_seg.transpose(3, 2, 0, 1))\n            results['gt_semantic_seg'] = DC(\n                to_tensor(semantic_seg), stack=True)\n        if 'gt_depth' in results:\n            gt_depth = results['gt_depth']\n            if len(gt_depth.shape) == 2:\n                gt_depth = gt_depth[None, ...]\n            else:\n                gt_depth = np.ascontiguousarray(\n                    gt_depth.transpose(3, 2, 0, 1))\n            results['gt_depth'] = DC(\n                to_tensor(gt_depth), stack=True)\n        return results\n\n    def __repr__(self):\n        return self.__class__.__name__\n\n\n@PIPELINES.register_module()\nclass VideoCollect(object):\n    \"\"\"Collect data from the loader relevant to the specific task.\n\n    Args:\n        keys (Sequence[str]): Keys of results to be collected in ``data``.\n        meta_keys (Sequence[str]): Meta keys to be converted to\n            ``mmcv.DataContainer`` and collected in ``data[img_metas]``.\n            Defaults to None.\n        default_meta_keys (tuple): Default meta keys. 
Defaults to ('filename',\n            'ori_filename', 'ori_shape', 'img_shape', 'pad_shape',\n            'scale_factor', 'flip', 'flip_direction', 'img_norm_cfg',\n            'frame_id', 'is_video_data').\n    \"\"\"\n\n    def __init__(self,\n                 keys,\n                 meta_keys=None,\n                 reject_empty=False,\n                 num_ref_imgs=0,\n                 # no_obj_class is added for handling non-0  no-obj class\n                 default_meta_keys=('filename', 'ori_filename', 'ori_shape',\n                                    'img_shape', 'pad_shape', 'scale_factor',\n                                    'flip', 'flip_direction', 'img_norm_cfg',\n                                    'video_id',\n                                    'frame_id', 'is_video_data', 'no_obj_class')):\n        self.keys = keys\n        self.meta_keys = default_meta_keys\n        if meta_keys is not None:\n            if isinstance(meta_keys, str):\n                meta_keys = (meta_keys,)\n            else:\n                assert isinstance(meta_keys, tuple), \\\n                    'meta_keys must be str or tuple'\n            self.meta_keys += meta_keys\n\n        self.reject_empty = reject_empty\n        self.num_ref_imgs = num_ref_imgs\n\n    def __call__(self, results):\n        \"\"\"Call function to collect keys in results.\n\n        The keys in ``meta_keys`` and ``default_meta_keys`` will be converted\n        to :obj:mmcv.DataContainer.\n\n        Args:\n            results (list[dict] | dict): List of dict or dict which contains\n                the data to collect.\n\n        Returns:\n            list[dict] | dict: List of dict or dict that contains the\n            following keys:\n\n            - keys in ``self.keys``\n            - ``img_metas``\n        \"\"\"\n        results_is_dict = isinstance(results, dict)\n        if results_is_dict:\n            results = [results]\n        outs = []\n        for _results in results:\n            
_results = self._add_default_meta_keys(_results)\n            _results = self._collect_meta_keys(_results)\n            outs.append(_results)\n\n        if results_is_dict:\n            outs[0]['img_metas'] = DC(outs[0]['img_metas'], cpu_only=True)\n\n        if self.reject_empty:\n            if len(results[0]['gt_labels']) == 0:\n                return None\n        if self.num_ref_imgs > 0:\n            if len(results) != self.num_ref_imgs + 1:\n                return None\n        return outs[0] if results_is_dict else outs\n\n    def _collect_meta_keys(self, results):\n        \"\"\"Collect `self.keys` and `self.meta_keys` from `results` (dict).\"\"\"\n        data = {}\n        img_meta = {}\n        for key in self.meta_keys:\n            if key in results:\n                img_meta[key] = results[key]\n            elif key in results['img_info']:\n                img_meta[key] = results['img_info'][key]\n        data['img_metas'] = img_meta\n        for key in self.keys:\n            data[key] = results[key]\n        return data\n\n    def _add_default_meta_keys(self, results):\n        \"\"\"Add default meta keys.\n\n        We set default meta keys including `pad_shape`, `scale_factor` and\n        `img_norm_cfg` to avoid the case where no `Resize`, `Normalize` and\n        `Pad` are implemented during the whole pipeline.\n\n        Args:\n            results (dict): Result dict contains the data to convert.\n\n        Returns:\n            results (dict): Updated result dict contains the data to convert.\n        \"\"\"\n        img = results['img']\n        results.setdefault('pad_shape', img.shape)\n        results.setdefault('scale_factor', 1.0)\n        num_channels = 1 if len(img.shape) < 3 else img.shape[2]\n        results.setdefault(\n            'img_norm_cfg',\n            dict(\n                mean=np.zeros(num_channels, dtype=np.float32),\n                std=np.ones(num_channels, dtype=np.float32),\n                to_rgb=False))\n        
return results\n\n\n@PIPELINES.register_module()\nclass ToList(object):\n    \"\"\"Wrap each value of the input dict in a single-element list.\n\n    Args:\n        results (dict): Result dict containing the data to convert.\n\n    Returns:\n        dict: A new dict with the same keys, where each value ``v`` is\n        replaced by ``[v]``.\n    \"\"\"\n\n    def __call__(self, results):\n        out = {}\n        for k, v in results.items():\n            out[k] = [v]\n        return out\n\n\n@PIPELINES.register_module()\nclass ReIDFormatBundle(object):\n    \"\"\"ReID formatting bundle.\n\n    It first stacks the common fields \"img\" and \"gt_label\" across the input\n    dicts (when a list is given), then formats them as follows.\n\n    - img: (1) transpose, (2) to tensor, (3) to DataContainer (stack=True)\n    - gt_label: (1) to tensor, (2) to DataContainer (stack=True)\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__()\n\n    def __call__(self, results):\n        \"\"\"ReID formatting bundle call function.\n\n        Args:\n            results (list[dict] or dict): List of dicts or dict.\n\n        Returns:\n            dict: The result dict containing the data formatted by the\n            ReID bundle.\n        \"\"\"\n        inputs = dict()\n        if isinstance(results, list):\n            assert len(results) > 1, \\\n                'the \\'results\\' only has one item, ' \\\n                'please directly use normal pipeline not \\'Seq\\' pipeline.'\n            inputs['img'] = np.stack([_results['img'] for _results in results],\n                                     axis=3)\n            inputs['gt_label'] = np.stack(\n                [_results['gt_label'] for _results in results], axis=0)\n        elif isinstance(results, dict):\n            inputs['img'] = results['img']\n            inputs['gt_label'] = results['gt_label']\n        else:\n            raise TypeError('results must be a list or a dict.')\n        outs = 
self.reid_format_bundle(inputs)\n\n        return outs\n\n    def reid_format_bundle(self, results):\n        \"\"\"Transform and format the img and gt_label fields in results.\n\n        Args:\n            results (dict): Result dict containing the data to convert.\n\n        Returns:\n            dict: The result dict containing the data formatted by the\n            ReID bundle.\n        \"\"\"\n        for key in results:\n            if key == 'img':\n                img = results[key]\n                if img.ndim == 3:\n                    img = np.ascontiguousarray(img.transpose(2, 0, 1))\n                else:\n                    img = np.ascontiguousarray(img.transpose(3, 2, 0, 1))\n                results['img'] = DC(to_tensor(img), stack=True)\n            elif key == 'gt_label':\n                results[key] = DC(\n                    to_tensor(results[key]), stack=True, pad_dims=None)\n            else:\n                raise KeyError(f'key {key} is not supported')\n        return results\n\n\n@PIPELINES.register_module()\nclass ImageToTensorWithRef(object):\n    \"\"\"Convert HWC images in ``results`` to CHW tensors.\n\n    ``ref_img`` may hold a list of reference images, which are stacked into\n    one (N, C, H, W) tensor. Every other key in ``keys`` must hold a single\n    HWC image.\n\n    Args:\n        keys (Sequence[str]): Keys of the images to convert.\n    \"\"\"\n\n    def __init__(self, keys):\n        self.keys = keys\n\n    def __call__(self, results):\n        for key in self.keys:\n            if key == 'ref_img':\n                if isinstance(results[key], list):\n                    # stack the reference frames into (N, C, H, W)\n                    img_ref = []\n                    for img in results[key]:\n                        img = np.ascontiguousarray(img.transpose(2, 0, 1))\n                        img_ref.append(img)\n                    img_ref = np.array(img_ref)\n                    results[key] = to_tensor(img_ref)\n                else:\n                    img = np.ascontiguousarray(results[key].transpose(2, 0, 1))\n                    results[key] = to_tensor(img)\n            else:\n                results[key] = to_tensor(results[key].transpose(2, 0, 1))\n        return results\n\n    def __repr__(self):\n        return self.__class__.__name__ + 
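'(keys={})'.format(self.keys)\n\n\n@PIPELINES.register_module()\nclass LabelConsistentChecker:\n    \"\"\"Check that instance annotations are consistent across a video clip.\n\n    ``ref_gt_instance_ids`` is expected to hold one (frame index, instance id)\n    row per instance per frame, grouped by frame, with every frame listing\n    the same instances in the same order. Samples that violate this are\n    rejected by returning None.\n\n    Args:\n        num_frames (int): Number of frames in each clip. Defaults to 5.\n    \"\"\"\n\n    def __init__(self, num_frames=5):\n        self.num_frames = num_frames\n\n    def __call__(self, results):\n        ref_gt_instance_ids = results['ref_gt_instance_ids'].data\n        ins_mul_nframe = ref_gt_instance_ids.size(0)\n        if ins_mul_nframe % self.num_frames != 0:\n            return None\n        num_ins = ins_mul_nframe // self.num_frames\n        ins_id_bucket = torch.zeros((num_ins,), dtype=torch.float)\n        for i in range(ins_mul_nframe):\n            frame_cur = i // num_ins\n            ins_cur = i % num_ins\n            if ref_gt_instance_ids[i][0] != frame_cur:\n                return None\n            if frame_cur == 0:\n                # remember the instance order of the first frame\n                ins_id_bucket[ins_cur] = ref_gt_instance_ids[i][1]\n            else:\n                # later frames must list the same instances in that order\n                if ref_gt_instance_ids[i][1] != ins_id_bucket[ins_cur]:\n                    return None\n        return results\n\n"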
  },
  {
    "path": "external/dataset/pipelines/loading.py",
    "content": "import os.path as osp\nimport numpy as np\n\nimport mmcv\nfrom mmdet.core import BitmapMasks\n\nfrom mmdet.datasets.builder import PIPELINES\nfrom mmdet.datasets.pipelines import LoadAnnotations, LoadImageFromFile\n\n\n@PIPELINES.register_module()\nclass LoadMultiImagesFromFile(LoadImageFromFile):\n    \"\"\"Load multiple images from files.\n    Please refer to `mmdet.datasets.pipelines.loading.py:LoadImageFromFile`\n    for detailed docstring.\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n\n    def __call__(self, results):\n        \"\"\"Call function.\n        For each dict in `results`, call the call function of\n        `LoadImageFromFile` to load the image.\n        Args:\n            results (list[dict]): List of dicts from\n                :obj:`mmtrack.CocoVideoDataset`.\n        Returns:\n            list[dict]: List of dicts that contain the loaded images.\n        \"\"\"\n        outs = []\n        for _results in results:\n            _results = super().__call__(_results)\n            outs.append(_results)\n        return outs\n\n\n@PIPELINES.register_module()\nclass SeqLoadAnnotations(LoadAnnotations):\n    \"\"\"Sequence load annotations.\n    Please refer to `mmdet.datasets.pipelines.loading.py:LoadAnnotations`\n    for detailed docstring.\n    Args:\n        with_track (bool): If True, load instance ids of bboxes.\n    \"\"\"\n\n    def __init__(self, with_track=False, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.with_track = with_track\n\n    def _load_track(self, results):\n        \"\"\"Private function to load instance (track) id annotations.\n        Args:\n            results (dict): Result dict from :obj:`mmtrack.CocoVideoDataset`.\n        Returns:\n            dict: The dict with loaded ``gt_instance_ids``.\n        \"\"\"\n\n        results['gt_instance_ids'] = results['ann_info']['instance_ids'].copy()\n\n        return results\n\n    def __call__(self, results):\n  
      \"\"\"Call function.\n        For each dict in results, call the call function of `LoadAnnotations`\n        to load annotation.\n        Args:\n            results (list[dict]): List of dicts from\n                :obj:`mmtrack.CocoVideoDataset`.\n        Returns:\n            list[dict]: List of dicts that contain loaded annotations, such as\n            bounding boxes, labels, instance ids, masks and semantic\n            segmentation annotations.\n        \"\"\"\n        outs = []\n        for _results in results:\n            _results = super().__call__(_results)\n            if self.with_track:\n                _results = self._load_track(_results)\n            outs.append(_results)\n        return outs\n\n\n@PIPELINES.register_module()\nclass LoadRefImageFromFile(object):\n    \"\"\"Load the current image and its reference frame.\n\n    The reference image path is read from ``img_info['ref_filename']`` and\n    joined with ``ref_prefix``. Specific to Cityscapes-VPS, Cityscapes, and\n    VIPER datasets.\n\n    Args:\n        sample (bool): Not used by this transform. Defaults to True.\n        to_float32 (bool): Whether to convert both images to float32.\n            Defaults to False.\n    \"\"\"\n\n    def __init__(self, sample=True, to_float32=False):\n        self.to_float32 = to_float32\n        self.sample = sample\n\n    def __call__(self, results):\n        # requires dirname for ref images\n        assert results['ref_prefix'] is not None, 'ref_prefix must be specified.'\n\n        filename = osp.join(results['img_prefix'],\n                            results['img_info']['filename'])\n        img = mmcv.imread(filename)\n        # the reference frame is specified by a separate ref json file\n        if 'ref_filename' in results['img_info']:\n            ref_filename = osp.join(results['ref_prefix'],\n                                    results['img_info']['ref_filename'])\n            ref_img = mmcv.imread(ref_filename)  # e.g. [1024, 2048, 3]\n        else:\n            raise NotImplementedError('img_info must provide ref_filename.')\n\n        if self.to_float32:\n            img = img.astype(np.float32)\n            ref_img = ref_img.astype(np.float32)\n\n        results['filename'] = filename\n        results['ori_filename'] = results['img_info']['filename']\n 
       results['img'] = img\n        results['img_shape'] = img.shape\n        results['ori_shape'] = img.shape\n        results['ref_img'] = ref_img\n        results['iid'] = results['img_info']['id']\n        return results\n\n    def __repr__(self):\n        return self.__class__.__name__ + '(to_float32={})'.format(\n            self.to_float32)\n\n\ndef bitmasks2bboxes(bitmasks):\n    \"\"\"Compute tight (x1, y1, x2, y2) boxes from :obj:`BitmapMasks`.\n\n    Coordinates are inclusive pixel indices; empty masks get an all-zero box.\n    \"\"\"\n    bitmasks_array = bitmasks.masks\n    boxes = np.zeros((bitmasks_array.shape[0], 4), dtype=np.float32)\n    x_any = np.any(bitmasks_array, axis=1)  # (N, W): columns containing the mask\n    y_any = np.any(bitmasks_array, axis=2)  # (N, H): rows containing the mask\n    for idx in range(bitmasks_array.shape[0]):\n        x = np.where(x_any[idx, :])[0]\n        y = np.where(y_any[idx, :])[0]\n        if len(x) > 0 and len(y) > 0:\n            boxes[idx, :] = np.array((x[0], y[0], x[-1], y[-1]), dtype=np.float32)\n    return boxes\n\n\n@PIPELINES.register_module()\nclass LoadAnnotationsInstanceMasks:\n    \"\"\"Load instance masks, labels, boxes and semantic maps from id maps.\n\n    Every unique value of at least 10000 in the instance map\n    (``ann_info['inst_map']``) becomes one instance mask labelled\n    ``value // 1000``; smaller values are not turned into instance masks.\n    Boxes are derived from the masks.\n\n    Args:\n        with_mask (bool): Whether to load instance masks. Defaults to True.\n        with_seg (bool): Whether to load the semantic map\n            (``ann_info['seg_map']``). Defaults to True.\n        with_inst (bool): Whether to also store the instance map as\n            ``gt_instance_map``, with values below 10000 multiplied by 1000.\n            Defaults to False.\n        cherry (Sequence[int], optional): If given, only keep instances whose\n            label is in it. Defaults to None.\n        file_client_args (dict): Arguments to instantiate a FileClient.\n            Defaults to ``dict(backend='disk')``.\n    \"\"\"\n\n    def __init__(self,\n                 with_mask=True,\n                 with_seg=True,\n                 with_inst=False,\n                 cherry=None,\n                 file_client_args=dict(backend='disk')):\n        self.with_mask = with_mask\n        self.with_seg = with_seg\n        self.with_inst = with_inst\n        self.file_client_args = file_client_args.copy()\n        self.cherry = cherry\n        self.file_client = None\n\n    def _load_masks(self, results):\n        \"\"\"Private function to load mask annotations.\n        Args:\n            results (dict): Result dict from :obj:`mmdet.CustomDataset`.\n        Returns:\n            dict | None: The dict contains loaded mask annotations as\n                :obj:`BitmapMasks`, or None if masks are requested but no\n                instance is found.
\n        \"\"\"\n\n        img_bytes = self.file_client.get(results['ann_info']['inst_map'])\n        inst_mask = mmcv.imfrombytes(img_bytes, flag='unchanged').squeeze()\n        if self.with_inst:\n            results['gt_instance_map'] = inst_mask.copy().astype(int)\n            results['gt_instance_map'][inst_mask < 10000] *= 1000\n        if not self.with_mask:\n            return results\n        masks = []\n        labels = []\n        for inst_id in np.unique(inst_mask):\n            if inst_id >= 10000:\n                if self.cherry is not None and inst_id // 1000 not in self.cherry:\n                    continue\n                # binary mask stored as uint8, the usual dtype for BitmapMasks\n                masks.append((inst_mask == inst_id).astype(np.uint8))\n                labels.append(inst_id // 1000)\n        if len(masks) == 0:\n            return None\n        gt_masks = BitmapMasks(masks, height=inst_mask.shape[0], width=inst_mask.shape[1])\n        results['gt_masks'] = gt_masks\n        results['mask_fields'].append('gt_masks')\n        results['gt_labels'] = np.array(labels)\n\n        boxes = bitmasks2bboxes(gt_masks)\n        results['gt_bboxes'] = boxes\n        results['bbox_fields'].append('gt_bboxes')\n        return results\n\n    def _load_semantic_seg(self, results):\n        \"\"\"Private function to load semantic segmentation annotations.\n        Args:\n            results (dict): Result dict from :obj:`dataset`.\n        Returns:\n            dict: The dict contains loaded semantic segmentation annotations.\n        \"\"\"\n        img_bytes = self.file_client.get(results['ann_info']['seg_map'])\n        results['gt_semantic_seg'] = mmcv.imfrombytes(\n            img_bytes, flag='unchanged').squeeze()\n        results['seg_fields'].append('gt_semantic_seg')\n        return results\n\n    def __call__(self, results):\n        \"\"\"Call function to load multiple types of annotations.\n        Args:\n            results (dict): Result dict from :obj:`mmdet.CustomDataset`.\n    
    Returns:\n            dict | None: The dict contains loaded bounding box, label, mask\n                and semantic segmentation annotations, or None if masks are\n                requested but the sample has no instance.\n        \"\"\"\n        if self.file_client is None:\n            self.file_client = mmcv.FileClient(**self.file_client_args)\n        if self.with_mask or self.with_inst:\n            results = self._load_masks(results)\n            if results is None:\n                return None\n        if self.with_seg:\n            results = self._load_semantic_seg(results)\n        return results\n\n    def __repr__(self):\n        repr_str = self.__class__.__name__\n        repr_str += f'(with_mask={self.with_mask}, '\n        repr_str += f'with_seg={self.with_seg}, '\n        repr_str += f'with_inst={self.with_inst}, '\n        repr_str += f'cherry={self.cherry}, '\n        repr_str += f'file_client_args={self.file_client_args})'\n        return repr_str\n"
  },
  {
    "path": "external/dataset/pipelines/test_time_aug.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport warnings\n\nimport mmcv\n\nfrom mmdet.datasets.builder import PIPELINES\nfrom mmdet.datasets.pipelines import Compose\n\n\n@PIPELINES.register_module()\nclass MultiScaleFlipAugVideo:\n    \"\"\"Test-time augmentation with multiple scales and flipping for videos.\n    ``results`` is a list of per-frame dicts; each augmented copy of the\n    whole list is passed to ``transforms``, so the transforms must accept a\n    list of dicts (e.g. the ``Seq*`` transforms rather than the image-level\n    ones shown in the example below, which follows ``MultiScaleFlipAug``).\n    An example configuration is as follows:\n    .. code-block::\n        img_scale=[(1333, 400), (1333, 800)],\n        flip=True,\n        transforms=[\n            dict(type='Resize', keep_ratio=True),\n            dict(type='RandomFlip'),\n            dict(type='Normalize', **img_norm_cfg),\n            dict(type='Pad', size_divisor=32),\n            dict(type='ImageToTensor', keys=['img']),\n            dict(type='Collect', keys=['img']),\n        ]\n    After MultiScaleFlipAugVideo with the above configuration, the results\n    are wrapped into lists of the same length as follows:\n    .. code-block::\n        dict(\n            img=[...],\n            img_shape=[...],\n            scale=[(1333, 400), (1333, 400), (1333, 800), (1333, 800)],\n            flip=[False, True, False, True],\n            ...\n        )\n    Args:\n        transforms (list[dict]): Transforms to apply in each augmentation.\n        img_scale (tuple | list[tuple] | None): Image scales for resizing.\n        scale_factor (float | list[float] | None): Scale factors for resizing.\n        flip (bool): Whether to apply flip augmentation. Default: False.\n        flip_direction (str | list[str]): Flip augmentation directions,\n            options are \"horizontal\", \"vertical\" and \"diagonal\". If\n            flip_direction is a list, multiple flip augmentations will be\n            applied. It has no effect when flip == False. 
Default:\n            \"horizontal\".\n    \"\"\"\n\n    def __init__(self,\n                 transforms,\n                 img_scale=None,\n                 scale_factor=None,\n                 flip=False,\n                 flip_direction='horizontal'):\n        self.transforms = Compose(transforms)\n        assert (img_scale is None) ^ (scale_factor is None), (\n            'Exactly one of img_scale and scale_factor must be set')\n        if img_scale is not None:\n            self.img_scale = img_scale if isinstance(img_scale,\n                                                     list) else [img_scale]\n            self.scale_key = 'scale'\n            assert mmcv.is_list_of(self.img_scale, tuple)\n        else:\n            self.img_scale = scale_factor if isinstance(\n                scale_factor, list) else [scale_factor]\n            self.scale_key = 'scale_factor'\n\n        self.flip = flip\n        self.flip_direction = flip_direction if isinstance(\n            flip_direction, list) else [flip_direction]\n        assert mmcv.is_list_of(self.flip_direction, str)\n        if not self.flip and self.flip_direction != ['horizontal']:\n            warnings.warn(\n                'flip_direction has no effect when flip is set to False')\n        if (self.flip and not any(\n                [t['type'] in ('RandomFlip', 'SeqRandomFlip')\n                 for t in transforms])):\n            warnings.warn('flip has no effect when neither RandomFlip nor '\n                          'SeqRandomFlip is in transforms')\n\n    def __call__(self, results):\n        \"\"\"Call function to apply test time augment transforms on results.\n        Args:\n            results (list[dict]): List of per-frame result dicts to transform.\n        Returns:\n           dict[str, list]: The augmented data, where each value is a list\n               with one entry per (scale, flip) combination.\n        \"\"\"\n\n        aug_data = []\n        flip_args = [(False, None)]\n        if self.flip:\n            flip_args += [(True, direction)\n                          for direction in self.flip_direction]\n    
    for scale in self.img_scale:\n            for flip, direction in flip_args:\n                _results = []\n                for results_single in results:\n                    _results_single = results_single.copy()\n                    _results_single[self.scale_key] = scale\n                    _results_single['flip'] = flip\n                    _results_single['flip_direction'] = direction\n                    _results.append(_results_single)\n                data = self.transforms(_results)\n                aug_data.append(data)\n        # list of dict to dict of list\n        aug_data_dict = {key: [] for key in aug_data[0]}\n        for data in aug_data:\n            for key, val in data.items():\n                aug_data_dict[key].append(val)\n        return aug_data_dict\n\n    def __repr__(self):\n        repr_str = self.__class__.__name__\n        repr_str += f'(transforms={self.transforms}, '\n        repr_str += f'img_scale={self.img_scale}, flip={self.flip}, '\n        repr_str += f'flip_direction={self.flip_direction})'\n        return repr_str"
  },
  {
    "path": "external/dataset/pipelines/transforms.py",
    "content": "import cv2\nimport mmcv\nimport numpy as np\nimport warnings\nfrom mmdet.datasets.builder import PIPELINES\nfrom mmdet.datasets.pipelines import Normalize, Pad, RandomFlip, Resize\n\n\n@PIPELINES.register_module()\nclass SeqColorAug(object):\n    \"\"\"Color augmentation for images.\n    Args:\n        prob (list[float]): The probability to perform color augmentation for\n            each image. Defaults to [1.0, 1.0].\n        rgb_var (list[list]): The values of color augmentation. Defaults to\n            [[-0.55919361, 0.98062831, -0.41940627],\n            [1.72091413, 0.19879334, -1.82968581],\n            [4.64467907, 4.73710203, 4.88324118]].\n    \"\"\"\n\n    def __init__(self,\n                 prob=[1.0, 1.0],\n                 rgb_var=[[-0.55919361, 0.98062831, -0.41940627],\n                          [1.72091413, 0.19879334, -1.82968581],\n                          [4.64467907, 4.73710203, 4.88324118]]):\n        self.prob = prob\n        self.rgb_var = np.array(rgb_var, dtype=np.float32)\n\n    def __call__(self, results):\n        \"\"\"Call function.\n        For each dict in results, perform color augmentation for the image in\n        the dict.\n        Args:\n            results (list[dict]): List of dicts from\n                :obj:`mmtrack.CocoVideoDataset`.\n        Returns:\n            list[dict]: List of dicts that contain the color-augmented images.\n        \"\"\"\n        outs = []\n        for i, _results in enumerate(results):\n            image = _results['img']\n\n            if self.prob[i] > np.random.random():\n                # random per-channel offset drawn from the rgb_var basis\n                offset = np.dot(self.rgb_var, np.random.randn(3, 1))\n                # reverse the channel order of the offset (RGB <-> BGR)\n                offset = offset[::-1]\n                offset = offset.reshape(3)\n                image = (image - offset).astype(np.float32)\n\n            _results['img'] = image\n            outs.append(_results)\n        return outs\n\n\n@PIPELINES.register_module()\nclass SeqBlurAug(object):\n    \"\"\"Blur 
augmentation for images.\n    Args:\n        prob (list[float]): The probability to perform blur augmentation for\n            each image. Defaults to [0.0, 0.2].\n    \"\"\"\n\n    def __init__(self, prob=[0.0, 0.2]):\n        self.prob = prob\n\n    def __call__(self, results):\n        \"\"\"Call function.\n        For each dict in results, perform blur augmentation for the image in\n        the dict.\n        Args:\n            results (list[dict]): List of dicts from\n                :obj:`mmtrack.CocoVideoDataset`.\n        Returns:\n            list[dict]: List of dicts that contain the blurred images.\n        \"\"\"\n        outs = []\n        for i, _results in enumerate(results):\n            image = _results['img']\n\n            if self.prob[i] > np.random.random():\n                # random odd kernel size in [5, 45]\n                sizes = np.arange(5, 46, 2)\n                size = np.random.choice(sizes)\n                kernel = np.zeros((size, size))\n                c = int(size / 2)\n                # cross-shaped kernel whose vertical and horizontal arms are\n                # weighted by wx and 1 - wx; the weights sum to 1\n                wx = np.random.random()\n                kernel[:, c] += 1. / size * wx\n                kernel[c, :] += 1. / size * (1 - wx)\n                image = cv2.filter2D(image, -1, kernel)\n\n            _results['img'] = image\n            outs.append(_results)\n        return outs\n\n\n@PIPELINES.register_module()\nclass SeqResize(Resize):\n    \"\"\"Resize images.\n    Please refer to `mmdet.datasets.pipelines.transforms.py:Resize` for\n    detailed docstring.\n    Args:\n        share_params (bool): If True, share the resize parameters for all\n            images. 
Defaults to True.\n    \"\"\"\n\n    def __init__(self, share_params=True, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.share_params = share_params\n\n    def __call__(self, results):\n        \"\"\"Call function.\n        For each dict in results, call the call function of `Resize` to resize\n        image and corresponding annotations.\n        Args:\n            results (list[dict]): List of dicts from\n                :obj:`mmtrack.CocoVideoDataset`.\n        Returns:\n            list[dict]: List of dicts that contain resized results,\n            'img_shape', 'pad_shape', 'scale_factor', 'keep_ratio' keys\n            are added into result dict.\n        \"\"\"\n        outs, scale = [], None\n        for i, _results in enumerate(results):\n            if self.share_params and i > 0:\n                # reuse the scale sampled for the first frame\n                _results['scale'] = scale\n            _results = super().__call__(_results)\n            if self.share_params and i == 0:\n                scale = _results['scale']\n            outs.append(_results)\n        return outs\n\n\n@PIPELINES.register_module()\nclass SeqNormalize(Normalize):\n    \"\"\"Normalize images.\n    Please refer to `mmdet.datasets.pipelines.transforms.py:Normalize` for\n    detailed docstring.\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n\n    def __call__(self, results):\n        \"\"\"Call function.\n        For each dict in results, call the call function of `Normalize` to\n        normalize image.\n        Args:\n            results (list[dict]): List of dicts from\n                :obj:`mmtrack.CocoVideoDataset`.\n        Returns:\n            list[dict]: List of dicts that contain normalized results,\n            'img_norm_cfg' key is added into result dict.\n        \"\"\"\n        outs = []\n        for _results in results:\n            _results = super().__call__(_results)\n            outs.append(_results)\n        return 
outs\n\n\n@PIPELINES.register_module()\nclass SeqRandomFlip(RandomFlip):\n    \"\"\"Randomly flip images.\n    Please refer to `mmdet.datasets.pipelines.transforms.py:RandomFlip` for\n    detailed docstring.\n    Args:\n        share_params (bool): If True, share the flip parameters for all images.\n            Defaults to True.\n    \"\"\"\n\n    def __init__(self, share_params=True, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.share_params = share_params\n\n    def __call__(self, results):\n        \"\"\"Call function.\n        For each dict in results, call `RandomFlip` to randomly flip image.\n        Args:\n            results (list[dict]): List of dicts from\n                :obj:`mmtrack.CocoVideoDataset`.\n        Returns:\n            list[dict]: List of dicts that contain flipped results, 'flip',\n            'flip_direction' keys are added into the dict.\n        \"\"\"\n        if self.share_params:\n            # sample one flip decision and apply it to every frame\n            if isinstance(self.direction, list):\n                # None means non-flip\n                direction_list = self.direction + [None]\n            else:\n                # None means non-flip\n                direction_list = [self.direction, None]\n\n            if isinstance(self.flip_ratio, list):\n                non_flip_ratio = 1 - sum(self.flip_ratio)\n                flip_ratio_list = self.flip_ratio + [non_flip_ratio]\n            else:\n                non_flip_ratio = 1 - self.flip_ratio\n                # exclude non-flip\n                single_ratio = self.flip_ratio / (len(direction_list) - 1)\n                flip_ratio_list = [single_ratio] * (len(direction_list) -\n                                                    1) + [non_flip_ratio]\n\n            cur_dir = np.random.choice(direction_list, p=flip_ratio_list)\n            flip = cur_dir is not None\n            flip_direction = cur_dir\n\n            for _results in results:\n                _results['flip'] = flip\n                
_results['flip_direction'] = flip_direction\n\n        outs = []\n        for _results in results:\n            _results = super().__call__(_results)\n            outs.append(_results)\n        return outs\n\n\n@PIPELINES.register_module()\nclass SeqPad(Pad):\n    \"\"\"Pad images.\n    Please refer to `mmdet.datasets.pipelines.transforms.py:Pad` for detailed\n    docstring.\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n\n    def __call__(self, results):\n        \"\"\"Call function.\n        For each dict in results, call the call function of `Pad` to pad image.\n        Args:\n            results (list[dict]): List of dicts from\n                :obj:`mmtrack.CocoVideoDataset`.\n        Returns:\n            list[dict]: List of dicts that contain padding results,\n            'pad_shape', 'pad_fixed_size' and 'pad_size_divisor' keys are\n            added into the dict.\n        \"\"\"\n        outs = []\n        for _results in results:\n            _results = super().__call__(_results)\n            outs.append(_results)\n        return outs\n\n\n@PIPELINES.register_module()\nclass SeqRandomCrop(object):\n    \"\"\"Randomly crop the images, bboxes and masks of a sequence.\n    A window of the absolute `crop_size` is placed at a random offset of each\n    image (or at one offset shared by all images), then the cropped results\n    are generated.\n    Args:\n        crop_size (tuple): The absolute (height, width) of the crop in pixels.\n        allow_negative_crop (bool, optional): Whether to allow a crop that does\n            not contain any bbox area. Default False.\n        share_params (bool, optional): Whether to share the cropping offsets\n            across the images. Default False.\n        check_id_match (bool, optional): When exactly two images are cropped,\n            whether to skip the sample if none of the instance ids in the\n            first image appears in the second. Default True.\n        bbox_clip_border (bool, optional): Whether to clip the objects outside\n            the border of the image. 
Defaults to True.\n    Note:\n        - If the image is smaller than the absolute crop size, return the\n            original image.\n        - The keys for bboxes, labels and masks must be aligned. That is,\n          `gt_bboxes` corresponds to `gt_labels` and `gt_masks`, and\n          `gt_bboxes_ignore` corresponds to `gt_labels_ignore` and\n          `gt_masks_ignore`.\n        - If the crop does not contain any gt-bbox region and\n          `allow_negative_crop` is set to False, skip this image.\n    \"\"\"\n\n    def __init__(self,\n                 crop_size,\n                 allow_negative_crop=False,\n                 share_params=False,\n                 bbox_clip_border=True,\n                 check_id_match=True\n                 ):\n        assert crop_size[0] > 0 and crop_size[1] > 0\n        self.crop_size = crop_size\n        self.allow_negative_crop = allow_negative_crop\n        self.share_params = share_params\n        self.bbox_clip_border = bbox_clip_border\n        self.check_id_match = check_id_match\n        # The key correspondence from bboxes to labels and masks.\n        self.bbox2label = {\n            'gt_bboxes': ['gt_labels', 'gt_instance_ids'],\n            'gt_bboxes_ignore': ['gt_labels_ignore', 'gt_instance_ids_ignore']\n        }\n        self.bbox2mask = {\n            'gt_bboxes': 'gt_masks',\n            'gt_bboxes_ignore': 'gt_masks_ignore'\n        }\n\n    def get_offsets(self, img):\n        \"\"\"Random generate the offsets for cropping.\"\"\"\n        margin_h = max(img.shape[0] - self.crop_size[0], 0)\n        margin_w = max(img.shape[1] - self.crop_size[1], 0)\n        offset_h = np.random.randint(0, margin_h + 1)\n        offset_w = np.random.randint(0, margin_w + 1)\n        return offset_h, offset_w\n\n    def random_crop(self, results, offsets=None):\n        \"\"\"Call function to randomly crop images, bounding boxes, masks,\n        semantic segmentation maps.\n        Args:\n            results (dict): Result 
dict from loading pipeline.\n            offsets (tuple, optional): Pre-defined offsets for cropping.\n                Default to None.\n        Returns:\n            dict: Randomly cropped results, 'img_shape' key in result dict is\n            updated according to crop size.\n        \"\"\"\n\n        for key in results.get('img_fields', ['img']):\n            img = results[key]\n            if offsets is not None:\n                offset_h, offset_w = offsets\n            else:\n                offset_h, offset_w = self.get_offsets(img)\n            results['img_info']['crop_offsets'] = (offset_h, offset_w)\n            crop_y1, crop_y2 = offset_h, offset_h + self.crop_size[0]\n            crop_x1, crop_x2 = offset_w, offset_w + self.crop_size[1]\n\n            # crop the image\n            img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...]\n            img_shape = img.shape\n            results[key] = img\n        results['img_shape'] = img_shape\n\n        # crop bboxes accordingly and clip to the image boundary\n        for key in results.get('bbox_fields', []):\n            # e.g. gt_bboxes and gt_bboxes_ignore\n            bbox_offset = np.array([offset_w, offset_h, offset_w, offset_h],\n                                   dtype=np.float32)\n            bboxes = results[key] - bbox_offset\n            if self.bbox_clip_border:\n                bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1])\n                bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0])\n            valid_inds = (bboxes[:, 2] > bboxes[:, 0]) & (\n                    bboxes[:, 3] > bboxes[:, 1])\n            # If the crop does not contain any gt-bbox area and\n            # self.allow_negative_crop is False, skip this image.\n            if (key == 'gt_bboxes' and not valid_inds.any()\n                    and not self.allow_negative_crop):\n                return None\n            results[key] = bboxes[valid_inds, :]\n            # label fields. e.g. 
gt_labels and gt_labels_ignore\n            label_keys = self.bbox2label.get(key, [])\n            for label_key in label_keys:\n                if label_key in results:\n                    results[label_key] = results[label_key][valid_inds]\n\n            # mask fields, e.g. gt_masks and gt_masks_ignore\n            mask_key = self.bbox2mask.get(key)\n            if mask_key in results:\n                results[mask_key] = results[mask_key][\n                    valid_inds.nonzero()[0]].crop(\n                    np.asarray([crop_x1, crop_y1, crop_x2, crop_y2]))\n\n        # crop semantic seg\n        for key in results.get('seg_fields', []):\n            results[key] = results[key][crop_y1:crop_y2, crop_x1:crop_x2]\n        return results\n\n    def __call__(self, results):\n        \"\"\"Call function to randomly crop the images, bounding boxes, masks\n        and semantic segmentation maps of a sequence.\n        Args:\n            results (list[dict]): List of result dicts from loading pipeline.\n        Returns:\n            list[dict] | None: Randomly cropped results, with 'img_shape'\n            updated according to crop size, or None if the sample is skipped.\n        \"\"\"\n        if self.share_params:\n            offsets = self.get_offsets(results[0]['img'])\n        else:\n            offsets = None\n\n        outs = []\n        for _results in results:\n            _results = self.random_crop(_results, offsets)\n            if _results is None:\n                return None\n            outs.append(_results)\n\n        if len(outs) == 2 and self.check_id_match:\n            ref_result, result = outs[1], outs[0]\n            if self.check_match(ref_result, result):\n                return None\n        return outs\n\n    def check_match(self, ref_results, results):\n        \"\"\"Return True if no instance id in ``results`` appears in\n        ``ref_results``.\"\"\"\n        ref_ids = ref_results['gt_instance_ids'].tolist()\n        gt_ids = results['gt_instance_ids'].tolist()\n        gt_pids = [ref_ids.index(i) if i in ref_ids else -1 for i in gt_ids]\n        nomatch = (np.array(gt_pids) == -1).all()\n  
      return nomatch\n\n\n@PIPELINES.register_module()\nclass SeqPhotoMetricDistortion(object):\n    \"\"\"Apply photometric distortion to a sequence of images. Every\n    transformation is applied with a probability of 0.5. Random contrast is\n    applied either second or second to last.\n    1. random brightness\n    2. random contrast (mode 1)\n    3. convert color from BGR to HSV\n    4. random saturation\n    5. random hue\n    6. convert color from HSV to BGR\n    7. random contrast (mode 0)\n    8. randomly swap channels\n    Args:\n        share_params (bool): whether all frames use the same sampled\n            distortion parameters.\n        brightness_delta (int): delta of brightness.\n        contrast_range (tuple): range of contrast.\n        saturation_range (tuple): range of saturation.\n        hue_delta (int): delta of hue.\n    \"\"\"\n\n    def __init__(self,\n                 share_params=True,\n                 brightness_delta=32,\n                 contrast_range=(0.5, 1.5),\n                 saturation_range=(0.5, 1.5),\n                 hue_delta=18):\n        self.share_params = share_params\n        self.brightness_delta = brightness_delta\n        self.contrast_lower, self.contrast_upper = contrast_range\n        self.saturation_lower, self.saturation_upper = saturation_range\n        self.hue_delta = hue_delta\n\n    def get_params(self):\n        \"\"\"Generate parameters.\"\"\"\n        params = dict()\n        # delta\n        if np.random.randint(2):\n            params['delta'] = np.random.uniform(-self.brightness_delta,\n                                                self.brightness_delta)\n        else:\n            params['delta'] = None\n        # mode\n        mode = np.random.randint(2)\n        params['contrast_first'] = bool(mode == 1)\n        # alpha\n        if np.random.randint(2):\n            params['alpha'] = np.random.uniform(self.contrast_lower,\n                                                self.contrast_upper)\n        else:\n            params['alpha'] = None\n        # saturation\n        if 
np.random.randint(2):\n            params['saturation'] = np.random.uniform(self.saturation_lower,\n                                                     self.saturation_upper)\n        else:\n            params['saturation'] = None\n        # hue\n        if np.random.randint(2):\n            params['hue'] = np.random.uniform(-self.hue_delta, self.hue_delta)\n        else:\n            params['hue'] = None\n        # swap\n        if np.random.randint(2):\n            params['permutation'] = np.random.permutation(3)\n        else:\n            params['permutation'] = None\n        return params\n\n    def photo_metric_distortion(self, results, params=None):\n        \"\"\"Call function to perform photometric distortion on images.\n        Args:\n            results (dict): Result dict from loading pipeline.\n            params (dict, optional): Pre-defined parameters. Defaults to None.\n        Returns:\n            dict: Result dict with images distorted.\n        \"\"\"\n        if params is None:\n            params = self.get_params()\n        results['img_info']['color_jitter'] = params\n\n        if 'img_fields' in results:\n            assert results['img_fields'] == ['img'], \\\n                'Only single img_fields is allowed'\n        img = results['img']\n        assert img.dtype == np.float32, \\\n            'PhotoMetricDistortion needs the input image of dtype np.float32,' \\\n            ' please set \"to_float32=True\" in \"LoadImageFromFile\" pipeline'\n        # random brightness\n        if params['delta'] is not None:\n            img += params['delta']\n\n        # contrast_first (mode == 1) --> do random contrast first\n        # otherwise (mode == 0) --> do random contrast last\n        if params['contrast_first']:\n            if params['alpha'] is not None:\n                img *= params['alpha']\n\n        # convert color from BGR to HSV\n        img = mmcv.bgr2hsv(img)\n\n        # random saturation\n        if params['saturation'] is not None:\n            img[..., 
1] *= params['saturation']\n\n        # random hue\n        if params['hue'] is not None:\n            img[..., 0] += params['hue']\n            img[..., 0][img[..., 0] > 360] -= 360\n            img[..., 0][img[..., 0] < 0] += 360\n\n        # convert color from HSV to BGR\n        img = mmcv.hsv2bgr(img)\n\n        # random contrast\n        if not params['contrast_first']:\n            if params['alpha'] is not None:\n                img *= params['alpha']\n\n        # randomly swap channels\n        if params['permutation'] is not None:\n            img = img[..., params['permutation']]\n\n        results['img'] = img\n        return results\n\n    def __call__(self, results):\n        \"\"\"Call function to perform photometric distortion on images.\n        Args:\n            results (dict): Result dict from loading pipeline.\n        Returns:\n            dict: Result dict with images distorted.\n        \"\"\"\n        if self.share_params:\n            params = self.get_params()\n        else:\n            params = None\n\n        outs = []\n        for _results in results:\n            _results = self.photo_metric_distortion(_results, params)\n            outs.append(_results)\n\n        return outs\n\n    def __repr__(self):\n        repr_str = self.__class__.__name__\n        repr_str += f'(\\nbrightness_delta={self.brightness_delta},\\n'\n        repr_str += 'contrast_range='\n        repr_str += f'{(self.contrast_lower, self.contrast_upper)},\\n'\n        repr_str += 'saturation_range='\n        repr_str += f'{(self.saturation_lower, self.saturation_upper)},\\n'\n        repr_str += f'hue_delta={self.hue_delta})'\n        return repr_str\n\n\n@PIPELINES.register_module()\nclass ResizeWithRef(object):\n    \"\"\"Resize images & bbox & mask.\n\n    This transform resizes the input image to some scale. Bboxes and masks are\n    then resized with the same scale factor. 
If the input dict contains the key\n    \"scale\", then the scale in the input dict is used, otherwise the specified\n    scale in the init method is used.\n\n    `img_scale` can either be a tuple (single-scale) or a list of tuples\n    (multi-scale). There are 3 multiscale modes:\n    - `ratio_range` is not None: randomly sample a ratio from the ratio range\n        and multiply it with the image scale.\n    - `ratio_range` is None and `multiscale_mode` == \"range\": randomly sample a\n        scale from a range.\n    - `ratio_range` is None and `multiscale_mode` == \"value\": randomly sample a\n        scale from multiple scales.\n\n    Args:\n        img_scale (tuple or list[tuple]): Image scales for resizing.\n        multiscale_mode (str): Either \"range\" or \"value\".\n        ratio_range (tuple[float]): (min_ratio, max_ratio)\n        keep_ratio (bool): Whether to keep the aspect ratio when resizing the\n            image.\n    \"\"\"\n\n    def __init__(self,\n                 img_scale=None,\n                 multiscale_mode='range',\n                 ratio_range=None,\n                 keep_ratio=True):\n        if img_scale is None:\n            self.img_scale = None\n        else:\n            if isinstance(img_scale, list):\n                self.img_scale = img_scale\n            else:\n                self.img_scale = [img_scale]\n            assert mmcv.is_list_of(self.img_scale, tuple)\n\n        if ratio_range is not None:\n            # mode 1: given a scale and a range of image ratio\n            assert len(self.img_scale) == 1\n        else:\n            # mode 2: given multiple scales or a range of scales\n            assert multiscale_mode in ['value', 'range']\n\n        self.multiscale_mode = multiscale_mode\n        self.ratio_range = ratio_range\n        self.keep_ratio = keep_ratio\n\n    @staticmethod\n    def random_select(img_scales):\n        assert mmcv.is_list_of(img_scales, tuple)\n        scale_idx = 
np.random.randint(len(img_scales))\n        img_scale = img_scales[scale_idx]\n        return img_scale, scale_idx\n\n    @staticmethod\n    def random_sample(img_scales):\n        assert mmcv.is_list_of(img_scales, tuple) and len(img_scales) == 2\n        img_scale_long = [max(s) for s in img_scales]\n        img_scale_short = [min(s) for s in img_scales]\n        long_edge = np.random.randint(\n            min(img_scale_long),\n            max(img_scale_long) + 1)\n        short_edge = np.random.randint(\n            min(img_scale_short),\n            max(img_scale_short) + 1)\n        img_scale = (long_edge, short_edge)\n        return img_scale, None\n\n    @staticmethod\n    def random_sample_ratio(img_scale, ratio_range):\n        assert isinstance(img_scale, tuple) and len(img_scale) == 2\n        min_ratio, max_ratio = ratio_range\n        assert min_ratio <= max_ratio\n        ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio\n        scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio)\n        return scale, None\n\n    def _random_scale(self, results):\n        if self.ratio_range is not None:\n            scale, scale_idx = self.random_sample_ratio(\n                self.img_scale[0], self.ratio_range)\n        elif len(self.img_scale) == 1:\n            scale, scale_idx = self.img_scale[0], 0\n        elif self.multiscale_mode == 'range':\n            scale, scale_idx = self.random_sample(self.img_scale)\n        elif self.multiscale_mode == 'value':\n            scale, scale_idx = self.random_select(self.img_scale)\n        else:\n            raise NotImplementedError\n\n        results['scale'] = scale\n        results['scale_idx'] = scale_idx\n\n    def _resize_img(self, results):\n        els = ['ref_img', 'img'] if 'ref_img' in results else ['img']\n        for el in els:\n            if self.keep_ratio:\n                img, scale_factor = mmcv.imrescale(\n                    results[el], results['scale'], 
return_scale=True)\n            else:\n                img, w_scale, h_scale = mmcv.imresize(\n                    results[el], results['scale'], return_scale=True)\n                scale_factor = np.array([w_scale, h_scale, w_scale, h_scale],\n                                        dtype=np.float32)\n            results[el] = img\n        results['img_shape'] = img.shape\n        results['pad_shape'] = img.shape  # in case that there is no padding\n        results['scale_factor'] = scale_factor\n        results['keep_ratio'] = self.keep_ratio\n\n    def _resize_bboxes(self, results):\n        els = ['ref_bbox_fields', 'bbox_fields'] if 'ref_bbox_fields' in results else ['bbox_fields']\n        for el in els:\n            img_shape = results['img_shape']\n            for key in results.get(el, []):\n                bboxes = results[key] * results['scale_factor']\n                bboxes[:, 0::2] = np.clip(\n                    bboxes[:, 0::2], 0, img_shape[1] - 1)\n                bboxes[:, 1::2] = np.clip(\n                    bboxes[:, 1::2], 0, img_shape[0] - 1)\n                results[key] = bboxes\n\n    def _resize_masks(self, results):\n        els = ['ref_mask_fields', 'mask_fields'] if 'ref_mask_fields' in results else ['mask_fields']\n        for el in els:\n            for key in results.get(el, []):\n                if results[key] is None:\n                    continue\n                if self.keep_ratio:\n                    masks = [\n                        mmcv.imrescale(\n                            mask, results['scale_factor'],\n                            interpolation='nearest')\n                        for mask in results[key]\n                    ]\n                else:\n                    mask_size = (results['img_shape'][1],\n                                 results['img_shape'][0])\n                    masks = [\n                        mmcv.imresize(mask, mask_size,\n                                      interpolation='nearest')\n     
                   for mask in results[key]\n                    ]\n                results[key] = masks\n\n    def __call__(self, results):\n        if 'scale' not in results:\n            self._random_scale(results)\n        self._resize_img(results)\n        self._resize_bboxes(results)\n        self._resize_masks(results)\n        # self._resize_semantic_seg(results)\n        return results\n\n    def __repr__(self):\n        repr_str = self.__class__.__name__\n        repr_str += ('(img_scale={}, multiscale_mode={}, ratio_range={}, '\n                     'keep_ratio={})').format(self.img_scale,\n                                              self.multiscale_mode,\n                                              self.ratio_range,\n                                              self.keep_ratio)\n        return repr_str\n\n\n@PIPELINES.register_module()\nclass RandomFlipWithRef(object):\n    \"\"\"Flip the image & bbox & mask.\n\n    If the input dict contains the key \"flip\", then the flag will be used,\n    otherwise it will be randomly decided by a ratio specified in the init\n    method.\n\n    Args:\n        flip_ratio (float, optional): The flipping probability.\n    \"\"\"\n\n    def __init__(self, flip_ratio=None):\n        self.flip_ratio = flip_ratio\n        if flip_ratio is not None:\n            assert flip_ratio >= 0 and flip_ratio <= 1\n\n    def bbox_flip(self, bboxes, img_shape):\n        \"\"\"Flip bboxes horizontally.\n\n        Args:\n            bboxes(ndarray): shape (..., 4*k)\n            img_shape(tuple): (height, width)\n        \"\"\"\n        assert bboxes.shape[-1] % 4 == 0\n        w = img_shape[1]\n        flipped = bboxes.copy()\n        flipped[..., 0::4] = w - bboxes[..., 2::4] - 1\n        flipped[..., 2::4] = w - bboxes[..., 0::4] - 1\n        return flipped\n\n    def __call__(self, results):\n        if 'flip' not in results:\n            flip = True if np.random.rand() < self.flip_ratio else False\n            results['flip'] 
= flip\n        if results['flip']:\n            # flip image\n            results['img'] = mmcv.imflip(results['img'])\n            if 'ref_img' in results:\n                results['ref_img'] = mmcv.imflip(results['ref_img'])\n            # flip bboxes\n            for key in results.get('bbox_fields', []):\n                results[key] = self.bbox_flip(results[key],\n                                              results['img_shape'])\n            for key in results.get('ref_bbox_fields', []):\n                results[key] = self.bbox_flip(results[key],\n                                              results['img_shape'])\n            # flip masks\n            for key in results.get('mask_fields', []):\n                results[key] = [mask[:, ::-1] for mask in results[key]]\n            for key in results.get('ref_mask_fields', []):\n                results[key] = [mask[:, ::-1] for mask in results[key]]\n        return results\n\n    def __repr__(self):\n        return self.__class__.__name__ + '(flip_ratio={})'.format(\n            self.flip_ratio)\n\n\n@PIPELINES.register_module()\nclass PadWithRef(object):\n    \"\"\"Pad the image & mask.\n\n    There are two padding modes: (1) pad to a fixed size and (2) pad to the\n    minimum size that is divisible by some number.\n\n    Args:\n        size (tuple, optional): Fixed padding size.\n        size_divisor (int, optional): The divisor of padded size.\n        pad_val (float, optional): Padding value, 0 by default.\n    \"\"\"\n\n    def __init__(self, size=None, size_divisor=None, pad_val=0):\n        self.size = size\n        self.size_divisor = size_divisor\n        self.pad_val = pad_val\n        # only one of size and size_divisor should be valid\n        assert size is not None or size_divisor is not None\n        assert size is None or size_divisor is None\n\n    def _pad_img(self, results):\n        els = ['ref_img', 'img'] if 'ref_img' in results else ['img']\n        for el in els:\n            if 
self.size is not None:\n                padded_img = mmcv.impad(\n                    results[el], shape=self.size, pad_val=self.pad_val)\n            elif self.size_divisor is not None:\n                padded_img = mmcv.impad_to_multiple(\n                    results[el], self.size_divisor, pad_val=self.pad_val)\n            results[el] = padded_img\n        results['pad_shape'] = padded_img.shape\n        results['pad_fixed_size'] = self.size\n        results['pad_size_divisor'] = self.size_divisor\n\n    def _pad_masks(self, results):\n        els = ['ref_mask_fields', 'mask_fields'] if 'ref_mask_fields' in results else ['mask_fields']\n        for el in els:\n            pad_shape = results['pad_shape'][:2]\n            for key in results.get(el, []):\n                padded_masks = [\n                    mmcv.impad(mask, shape=pad_shape, pad_val=self.pad_val)\n                    for mask in results[key]\n                ]\n                results[key] = np.stack(padded_masks, axis=0)\n\n    def __call__(self, results):\n        self._pad_img(results)\n        self._pad_masks(results)\n        return results\n\n    def __repr__(self):\n        repr_str = self.__class__.__name__\n        repr_str += '(size={}, size_divisor={}, pad_val={})'.format(\n            self.size, self.size_divisor, self.pad_val)\n        return repr_str\n\n\n@PIPELINES.register_module()\nclass NormalizeWithRef(object):\n    \"\"\"Normalize the image.\n\n    Args:\n        mean (sequence): Mean values of 3 channels.\n        std (sequence): Std values of 3 channels.\n        to_rgb (bool): Whether to convert the image from BGR to RGB,\n            default is true.\n    \"\"\"\n\n    def __init__(self, mean, std, to_rgb=True):\n        self.mean = np.array(mean, dtype=np.float32)\n        self.std = np.array(std, dtype=np.float32)\n        self.to_rgb = to_rgb\n\n    def __call__(self, results):\n        results['img'] = mmcv.imnormalize(\n            results['img'], self.mean, self.std, self.to_rgb)\n        if 'ref_img' in results:\n     
       results['ref_img'] = mmcv.imnormalize(\n                results['ref_img'], self.mean, self.std, self.to_rgb)\n        results['img_norm_cfg'] = dict(\n            mean=self.mean, std=self.std, to_rgb=self.to_rgb)\n        return results\n\n    def __repr__(self):\n        repr_str = self.__class__.__name__\n        repr_str += '(mean={}, std={}, to_rgb={})'.format(\n            self.mean, self.std, self.to_rgb)\n        return repr_str\n\n\n@PIPELINES.register_module()\nclass RandomCropWithRef(object):\n    \"\"\"Randomly crop the image & bboxes & masks.\n\n    Args:\n        crop_size (tuple): Expected size after cropping, (h, w).\n    \"\"\"\n\n    def __init__(self, crop_size):\n        self.crop_size = crop_size\n\n    def __call__(self, results):\n        img = results['img']\n\n        margin_h = max(img.shape[0] - self.crop_size[0], 0)\n        margin_w = max(img.shape[1] - self.crop_size[1], 0)\n        offset_h = np.random.randint(0, margin_h + 1)\n        offset_w = np.random.randint(0, margin_w + 1)\n        crop_y1, crop_y2 = offset_h, offset_h + self.crop_size[0]\n        crop_x1, crop_x2 = offset_w, offset_w + self.crop_size[1]\n\n        # crop the image\n        ori_shape = img.shape\n        img = img[crop_y1:crop_y2, crop_x1:crop_x2, :]\n        img_shape = img.shape\n        results['img'] = img\n        if 'ref_img' in results:\n            ref_img = results['ref_img']\n            ref_img = ref_img[crop_y1:crop_y2, crop_x1:crop_x2, :]\n            results['ref_img'] = ref_img\n        results['img_shape'] = img_shape\n        results['crop_coords'] = [crop_y1, crop_y2, crop_x1, crop_x2]\n\n        # crop bboxes accordingly and clip to the image boundary\n        els = ['ref_bbox_fields', 'bbox_fields'] if 'ref_bbox_fields' in results else ['bbox_fields']\n        for el in els:\n            for key in results.get(el, []):\n                bbox_offset = np.array(\n                    [offset_w, offset_h, offset_w, offset_h],\n              
      dtype=np.float32)\n                bboxes = results[key] - bbox_offset\n                bboxes[:, 0::2] = np.clip(\n                    bboxes[:, 0::2], 0, img_shape[1] - 1)\n                bboxes[:, 1::2] = np.clip(\n                    bboxes[:, 1::2], 0, img_shape[0] - 1)\n                results[key] = bboxes\n\n        # filter out the gt bboxes that are completely cropped\n        els = ['ref_bboxes', 'gt_bboxes'] if 'ref_bboxes' in results else ['gt_bboxes']\n        for el in els:\n            if el in results:\n                gt_bboxes = results[el]\n                valid_inds = (gt_bboxes[:, 2] > gt_bboxes[:, 0]) & (\n                        gt_bboxes[:, 3] > gt_bboxes[:, 1])\n                # if no gt bbox remains after cropping, just skip this image\n                if not np.any(valid_inds):\n                    return None\n                results[el] = gt_bboxes[valid_inds, :]\n                ell = el.replace('_bboxes', '_labels')\n                if ell in results:\n                    results[ell] = results[ell][valid_inds]\n                # filter gt_obj_ids the same way as gt_labels\n                elo = el.replace('_bboxes', '_obj_ids')\n                if elo in results:\n                    results[elo] = results[elo][valid_inds]\n                # filter and crop the masks\n                elm = el.replace('_bboxes', '_masks')\n                if elm in results:\n                    valid_gt_masks = []\n                    for i in np.where(valid_inds)[0]:\n                        gt_mask = results[elm][i][\n                                  crop_y1:crop_y2, crop_x1:crop_x2]\n                        valid_gt_masks.append(gt_mask)\n                    results[elm] = valid_gt_masks\n\n        return results\n\n    def __repr__(self):\n        return self.__class__.__name__ + '(crop_size={})'.format(\n            self.crop_size)\n\n\n@PIPELINES.register_module()\nclass PadFutureMMDet:\n    \"\"\"Pad the image & masks & segmentation 
map.\n    There are two padding modes: (1) pad to a fixed size and (2) pad to the\n    minimum size that is divisible by some number.\n    Added keys are \"pad_shape\", \"pad_fixed_size\", \"pad_size_divisor\".\n    Args:\n        size (tuple, optional): Fixed padding size.\n        size_divisor (int, optional): The divisor of padded size.\n        pad_to_square (bool): Whether to pad the image into a square.\n            Currently only used for YOLOX. Default: False.\n        pad_val (dict, optional): A dict for padding value, the default\n            value is `dict(img=0, masks=0, seg=255)`.\n    \"\"\"\n\n    def __init__(self,\n                 size=None,\n                 size_divisor=None,\n                 pad_to_square=False,\n                 pad_val=dict(img=0, masks=0, seg=255)):\n        self.size = size\n        self.size_divisor = size_divisor\n        if isinstance(pad_val, float) or isinstance(pad_val, int):\n            warnings.warn(\n                'pad_val of float type is deprecated now, '\n                f'please use pad_val=dict(img={pad_val}, '\n                f'masks={pad_val}, seg=255) instead.', DeprecationWarning)\n            pad_val = dict(img=pad_val, masks=pad_val, seg=255)\n        assert isinstance(pad_val, dict)\n        self.pad_val = pad_val\n        self.pad_to_square = pad_to_square\n\n        if pad_to_square:\n            assert size is None and size_divisor is None, \\\n                'The size and size_divisor must be None ' \\\n                'when pad_to_square is True'\n        else:\n            assert size is not None or size_divisor is not None, \\\n                'one of size and size_divisor must be specified'\n            assert size is None or size_divisor is None, \\\n                'only one of size and size_divisor should be valid'\n\n    def _pad_img(self, results):\n        \"\"\"Pad images according to ``self.size``.\"\"\"\n        pad_val = self.pad_val.get('img', 0)\n        for key in results.get('img_fields', ['img']):\n            if self.pad_to_square:\n             
   max_size = max(results[key].shape[:2])\n                self.size = (max_size, max_size)\n            if self.size is not None:\n                padded_img = mmcv.impad(\n                    results[key], shape=self.size, pad_val=pad_val)\n            elif self.size_divisor is not None:\n                padded_img = mmcv.impad_to_multiple(\n                    results[key], self.size_divisor, pad_val=pad_val)\n            results[key] = padded_img\n        results['pad_shape'] = padded_img.shape\n        results['pad_fixed_size'] = self.size\n        results['pad_size_divisor'] = self.size_divisor\n\n    def _pad_masks(self, results):\n        \"\"\"Pad masks according to ``results['pad_shape']``.\"\"\"\n        pad_shape = results['pad_shape'][:2]\n        pad_val = self.pad_val.get('masks', 0)\n        for key in results.get('mask_fields', []):\n            results[key] = results[key].pad(pad_shape, pad_val=pad_val)\n\n    def _pad_seg(self, results):\n        \"\"\"Pad semantic segmentation map according to\n        ``results['pad_shape']``.\"\"\"\n        pad_val = self.pad_val.get('seg', 255)\n        for key in results.get('seg_fields', []):\n            results[key] = mmcv.impad(\n                results[key], shape=results['pad_shape'][:2], pad_val=pad_val)\n\n    def __call__(self, results):\n        \"\"\"Call function to pad images, masks, semantic segmentation maps.\n        Args:\n            results (dict): Result dict from loading pipeline.\n        Returns:\n            dict: Updated result dict.\n        \"\"\"\n        self._pad_img(results)\n        self._pad_masks(results)\n        self._pad_seg(results)\n        return results\n\n    def __repr__(self):\n        repr_str = self.__class__.__name__\n        repr_str += f'(size={self.size}, '\n        repr_str += f'size_divisor={self.size_divisor}, '\n        repr_str += f'pad_to_square={self.pad_to_square}, '\n        repr_str += f'pad_val={self.pad_val})'\n        return 
repr_str\n\n\n@PIPELINES.register_module()\nclass KNetInsAdapter:\n    \"\"\"Adapter that converts city-style instance class ids (starting at\n    ``stuff_nums``, 11 by default) to coco-style class ids (starting at 0).\n    \"\"\"\n\n    def __init__(self, stuff_nums=11):\n        self.stuff_nums = stuff_nums\n\n    def __call__(self, results):\n        \"\"\"Call function to modify gt_labels\n        Args:\n            results (dict): Result dict from loading pipeline.\n        Returns:\n            dict: Updated result dict.\n        \"\"\"\n        results['gt_labels'] -= self.stuff_nums\n        return results\n\n\n@PIPELINES.register_module()\nclass KNetInsAdapterCherryPick:\n    \"\"\"Adapter like ``KNetInsAdapter`` for a cherry-picked subset of classes.\n    The i-th class id in ``cherry`` (0-indexed) is first shifted down by i,\n    then ``stuff_nums`` is subtracted from all labels. With the defaults,\n    11 is mapped to 0 and 13 is mapped to 1.\n    \"\"\"\n\n    def __init__(self, stuff_nums=11, cherry=(11, 13)):\n        self.cherry = cherry\n        self.stuff_nums = stuff_nums\n\n    def __call__(self, results):\n        \"\"\"Call function to modify gt_labels\n        Args:\n            results (dict): Result dict from loading pipeline.\n        Returns:\n            dict: Updated result dict.\n        \"\"\"\n        bias = 0\n        for ch in self.cherry:\n            results['gt_labels'][results['gt_labels'] == ch] -= bias\n            bias += 1\n        results['gt_labels'] -= self.stuff_nums\n        return results\n"
  },
  {
    "path": "external/evalhooks.py",
    "content": "import os.path as osp\nimport warnings\nfrom math import inf\n\nimport mmcv\nimport torch.distributed as dist\nfrom mmcv.runner import Hook\nfrom mmdet.utils import get_root_logger\nfrom torch.nn.modules.batchnorm import _BatchNorm\nfrom torch.utils.data import DataLoader\n\nfrom external.test import multi_gpu_test, single_gpu_test\n\n\nclass EvalHook(Hook):\n    \"\"\"Evaluation hook.\n\n    Notes:\n        If new arguments are added for EvalHook, tools/test.py,\n        tools/analysis_tools/eval_metric.py may be affected.\n\n    Attributes:\n        dataloader (DataLoader): A PyTorch dataloader.\n        start (int, optional): Evaluation starting epoch. It enables evaluation\n            before the training starts if ``start`` <= the resuming epoch.\n            If None, whether to evaluate is merely decided by ``interval``.\n            Default: None.\n        interval (int): Evaluation interval (by epochs). Default: 1.\n        save_best (str, optional): If a metric is specified, it would measure\n            the best checkpoint during evaluation. The best checkpoint is\n            symlinked as ``best_{key_indicator}.pth`` in the work directory.\n            Options are the evaluation metrics of the test dataset, e.g.,\n            ``bbox_mAP``, ``segm_mAP`` for bbox detection and instance\n            segmentation. ``AR@100`` for proposal recall. If ``save_best`` is\n            ``auto``, the first key will be used. The interval of\n            ``CheckpointHook`` should divide that of ``EvalHook``.\n            Default: None.\n        rule (str, optional): Comparison rule for best score. If set to None,\n            it will infer a reasonable rule. Keys such as 'mAP' or 'AR' will\n            be inferred by 'greater' rule. Keys containing 'loss' will be\n            inferred by 'less' rule. Options are 'greater', 'less'. 
Default: None.\n        **eval_kwargs: Evaluation arguments fed into the evaluate function of\n            the dataset.\n    \"\"\"\n\n    rule_map = {'greater': lambda x, y: x > y, 'less': lambda x, y: x < y}\n    init_value_map = {'greater': -inf, 'less': inf}\n    greater_keys = ['mAP', 'AR']\n    less_keys = ['loss']\n\n    def __init__(self,\n                 dataloader,\n                 start=None,\n                 interval=1,\n                 by_epoch=True,\n                 save_best=None,\n                 rule=None,\n                 **eval_kwargs):\n        if not isinstance(dataloader, DataLoader):\n            raise TypeError('dataloader must be a pytorch DataLoader, but got'\n                            f' {type(dataloader)}')\n        if not interval > 0:\n            raise ValueError(f'interval must be positive, but got {interval}')\n        if start is not None and start < 0:\n            warnings.warn(\n                f'The evaluation start epoch {start} is smaller than 0, '\n                f'use 0 instead', UserWarning)\n            start = 0\n        self.dataloader = dataloader\n        self.interval = interval\n        self.by_epoch = by_epoch\n        self.start = start\n        assert isinstance(save_best, str) or save_best is None\n        self.save_best = save_best\n        self.eval_kwargs = eval_kwargs\n        self.initial_epoch_flag = True\n\n        self.logger = get_root_logger()\n\n        if self.save_best is not None:\n            self._init_rule(rule, self.save_best)\n\n    def _init_rule(self, rule, key_indicator):\n        \"\"\"Initialize rule, key_indicator, comparison_func, and best score.\n\n        Args:\n            rule (str | None): Comparison rule for best score.\n            key_indicator (str | None): Key indicator to determine the\n                comparison rule.\n        \"\"\"\n        if rule not in self.rule_map and rule is not None:\n            raise KeyError(f'rule must be greater, less or None, '\n     
                      f'but got {rule}.')\n\n        if rule is None:\n            if key_indicator != 'auto':\n                if any(key in key_indicator for key in self.greater_keys):\n                    rule = 'greater'\n                elif any(key in key_indicator for key in self.less_keys):\n                    rule = 'less'\n                else:\n                    raise ValueError(f'Cannot infer the rule for key '\n                                     f'{key_indicator}, thus a specific rule '\n                                     f'must be specified.')\n        self.rule = rule\n        self.key_indicator = key_indicator\n        if self.rule is not None:\n            self.compare_func = self.rule_map[self.rule]\n\n    def before_run(self, runner):\n        if self.save_best is not None:\n            if runner.meta is None:\n                warnings.warn('runner.meta is None. Creating an empty one.')\n                runner.meta = dict()\n            runner.meta.setdefault('hook_msgs', dict())\n\n    def before_train_epoch(self, runner):\n        \"\"\"Evaluate the model only at the start of training.\"\"\"\n        if not self.initial_epoch_flag:\n            return\n        if self.start is not None and runner.epoch >= self.start:\n            self.after_train_epoch(runner)\n        self.initial_epoch_flag = False\n\n    def evaluation_flag(self, runner):\n        \"\"\"Judge whether to perform evaluation after this epoch.\n\n        Returns:\n            bool: The flag indicating whether to perform evaluation.\n        \"\"\"\n        if self.start is None:\n            if not self.every_n_epochs(runner, self.interval):\n                # No evaluation during the interval epochs.\n                return False\n        elif (runner.epoch + 1) < self.start:\n            # No evaluation if start is larger than the current epoch.\n            return False\n        else:\n            # Evaluation only at epochs 3, 5, 7... 
if start==3 and interval==2\n            if (runner.epoch + 1 - self.start) % self.interval:\n                return False\n        return True\n\n    def after_train_epoch(self, runner):\n        if not self.by_epoch or not self.evaluation_flag(runner):\n            return\n        results = single_gpu_test(runner.model, self.dataloader, show=False)\n        key_score = self.evaluate(runner, results)\n        if self.save_best:\n            self.save_best_checkpoint(runner, key_score)\n\n    def after_train_iter(self, runner):\n        if self.by_epoch or not self.every_n_iters(runner, self.interval):\n            return\n        results = single_gpu_test(runner.model, self.dataloader, show=False)\n        key_score = self.evaluate(runner, results)\n        if self.save_best:\n            self.save_best_checkpoint(runner, key_score)\n\n    def save_best_checkpoint(self, runner, key_score):\n        best_score = runner.meta['hook_msgs'].get(\n            'best_score', self.init_value_map[self.rule])\n        if self.compare_func(key_score, best_score):\n            best_score = key_score\n            runner.meta['hook_msgs']['best_score'] = best_score\n            last_ckpt = runner.meta['hook_msgs']['last_ckpt']\n            runner.meta['hook_msgs']['best_ckpt'] = last_ckpt\n            mmcv.symlink(\n                last_ckpt,\n                osp.join(runner.work_dir, f'best_{self.key_indicator}.pth'))\n            time_stamp = runner.epoch + 1 if self.by_epoch else runner.iter + 1\n            self.logger.info(f'Now best checkpoint is epoch_{time_stamp}.pth. '\n                             f'Best {self.key_indicator} is {best_score:0.4f}')\n\n    def evaluate(self, runner, results):\n        eval_res = self.dataloader.dataset.evaluate(\n            results, logger=runner.logger, **self.eval_kwargs)\n        for name, val in eval_res.items():\n            runner.log_buffer.output[name] = val\n        runner.log_buffer.ready = True\n        if self.save_best is 
not None:\n            if self.key_indicator == 'auto':\n                # infer from eval_results\n                self._init_rule(self.rule, list(eval_res.keys())[0])\n            return eval_res[self.key_indicator]\n        else:\n            return None\n\n\nclass DistEvalHook(EvalHook):\n    \"\"\"Distributed evaluation hook.\n\n    Notes:\n        If new arguments are added, tools/test.py may be affected.\n\n    Attributes:\n        dataloader (DataLoader): A PyTorch dataloader.\n        start (int, optional): Evaluation starting epoch. It enables evaluation\n            before the training starts if ``start`` <= the resuming epoch.\n            If None, whether to evaluate is merely decided by ``interval``.\n            Default: None.\n        interval (int): Evaluation interval (in epochs if ``by_epoch`` is\n            True, otherwise in iterations). Default: 1.\n        tmpdir (str | None): Temporary directory to save the results of all\n            processes. Default: None.\n        gpu_collect (bool): Whether to use gpu or cpu to collect results.\n            Default: False.\n        save_best (str, optional): If a metric is specified, it would measure\n            the best checkpoint during evaluation. The best score and\n            checkpoint path are stored in ``runner.meta['hook_msgs']``, and a\n            ``best_{key_indicator}.pth`` symlink is created in ``work_dir``.\n            Options are the evaluation metrics of the test dataset, e.g.,\n            ``bbox_mAP``, ``segm_mAP`` for bbox detection and instance\n            segmentation. ``AR@100`` for proposal recall. If ``save_best`` is\n            ``auto``, the first key will be used. The interval of\n            ``CheckpointHook`` should divide that of ``EvalHook``.\n            Default: None.\n        rule (str | None): Comparison rule for best score. If set to None,\n            it will infer a reasonable rule. Default: None.\n        broadcast_bn_buffer (bool): Whether to broadcast the\n            buffers (running_mean and running_var) of rank 0 to other ranks\n            before evaluation. 
Default: True.\n        **eval_kwargs: Evaluation arguments fed into the evaluate function of\n            the dataset.\n    \"\"\"\n\n    def __init__(self,\n                 dataloader,\n                 start=None,\n                 interval=1,\n                 by_epoch=True,\n                 tmpdir=None,\n                 gpu_collect=False,\n                 save_best=None,\n                 rule=None,\n                 broadcast_bn_buffer=True,\n                 **eval_kwargs):\n        super().__init__(\n            dataloader,\n            start=start,\n            interval=interval,\n            by_epoch=by_epoch,\n            save_best=save_best,\n            rule=rule,\n            **eval_kwargs)\n        self.broadcast_bn_buffer = broadcast_bn_buffer\n        self.tmpdir = tmpdir\n        self.gpu_collect = gpu_collect\n\n    def _broadcast_bn_buffer(self, runner):\n        # Synchronization of BatchNorm's buffer (running_mean\n        # and running_var) is not supported in the DDP of pytorch,\n        # which may cause the inconsistent performance of models in\n        # different ranks, so we broadcast BatchNorm's buffers\n        # of rank 0 to other ranks to avoid this.\n        if self.broadcast_bn_buffer:\n            model = runner.model\n            for name, module in model.named_modules():\n                if isinstance(module,\n                              _BatchNorm) and module.track_running_stats:\n                    dist.broadcast(module.running_var, 0)\n                    dist.broadcast(module.running_mean, 0)\n\n    def after_train_epoch(self, runner):\n        if not self.by_epoch or not self.evaluation_flag(runner):\n            return\n\n        if self.broadcast_bn_buffer:\n            self._broadcast_bn_buffer(runner)\n\n        tmpdir = self.tmpdir\n        if tmpdir is None:\n            tmpdir = osp.join(runner.work_dir, '.eval_hook')\n        results = multi_gpu_test(\n            runner.model,\n            
self.dataloader,\n            tmpdir=tmpdir,\n            gpu_collect=self.gpu_collect)\n        if runner.rank == 0:\n            print('\\n')\n            key_score = self.evaluate(runner, results)\n            if self.save_best:\n                self.save_best_checkpoint(runner, key_score)\n\n    def after_train_iter(self, runner):\n        if self.by_epoch or not self.every_n_iters(runner, self.interval):\n            return\n\n        if self.broadcast_bn_buffer:\n            self._broadcast_bn_buffer(runner)\n\n        tmpdir = self.tmpdir\n        if tmpdir is None:\n            tmpdir = osp.join(runner.work_dir, '.eval_hook')\n        results = multi_gpu_test(\n            runner.model,\n            self.dataloader,\n            tmpdir=tmpdir,\n            gpu_collect=self.gpu_collect)\n        if runner.rank == 0:\n            print('\\n')\n            key_score = self.evaluate(runner, results)\n            if self.save_best:\n                self.save_best_checkpoint(runner, key_score)\n"
  },
  {
    "path": "external/ext/mask.py",
    "content": "__author__ = 'tsungyi'\n\nimport pycocotools._mask as _mask\n\n# Interface for manipulating masks stored in RLE format.\n#\n# RLE is a simple yet efficient format for storing binary masks. RLE\n# first divides a vector (or vectorized image) into a series of piecewise\n# constant regions and then for each piece simply stores the length of\n# that piece. For example, given M=[0 0 1 1 1 0 1] the RLE counts would\n# be [2 3 1 1], or for M=[1 1 1 1 1 1 0] the counts would be [0 6 1]\n# (note that the odd counts are always the numbers of zeros). Instead of\n# storing the counts directly, additional compression is achieved with a\n# variable bitrate representation based on a common scheme called LEB128.\n#\n# Compression is greatest given large piecewise constant regions.\n# Specifically, the size of the RLE is proportional to the number of\n# *boundaries* in M (or for an image the number of boundaries in the y\n# direction). Assuming fairly simple shapes, the RLE representation is\n# O(sqrt(n)) where n is number of pixels in the object. Hence space usage\n# is substantially lower, especially for large simple objects (large n).\n#\n# Many common operations on masks can be computed directly using the RLE\n# (without need for decoding). This includes computations such as area,\n# union, intersection, etc. All of these operations are linear in the\n# size of the RLE, in other words they are O(sqrt(n)) where n is the area\n# of the object. 
Computing these operations on the original mask is O(n).\n# Thus, using the RLE can result in substantial computational savings.\n#\n# The following API functions are defined:\n#  encode         - Encode binary masks using RLE.\n#  decode         - Decode binary masks encoded via RLE.\n#  merge          - Compute union or intersection of encoded masks.\n#  iou            - Compute intersection over union between masks.\n#  area           - Compute area of encoded masks.\n#  toBbox         - Get bounding boxes surrounding encoded masks.\n#  frPyObjects    - Convert polygon, bbox, and uncompressed RLE to encoded RLE mask.\n#\n# Usage:\n#  Rs     = encode( masks )\n#  masks  = decode( Rs )\n#  R      = merge( Rs, intersect=false )\n#  o      = iou( dt, gt, iscrowd )\n#  a      = area( Rs )\n#  bbs    = toBbox( Rs )\n#  Rs     = frPyObjects( [pyObjects], h, w )\n#\n# In the API the following formats are used:\n#  Rs      - [dict] Run-length encoding of binary masks\n#  R       - dict Run-length encoding of binary mask\n#  masks   - [hxwxn] Binary mask(s) (must have type np.ndarray(dtype=uint8) in column-major order)\n#  iscrowd - [nx1] list of np.ndarray. 1 indicates corresponding gt image has crowd region to ignore\n#  bbs     - [nx4] Bounding box(es) stored as [x y w h]\n#  poly    - Polygon stored as [[x1 y1 x2 y2...],[x1 y1 ...],...] (2D list)\n#  dt,gt   - May be either bounding boxes or encoded masks\n# Both poly and bbs are 0-indexed (bbox=[0 0 1 1] encloses first pixel).\n#\n# Finally, a note about the intersection over union (iou) computation.\n# The standard iou of a ground truth (gt) and detected (dt) object is\n#  iou(gt,dt) = area(intersect(gt,dt)) / area(union(gt,dt))\n# For \"crowd\" regions, we use a modified criteria. If a gt object is\n# marked as \"iscrowd\", we allow a dt to match any subregion of the gt.\n# Choosing gt' in the crowd gt that best matches the dt can be done using\n# gt'=intersect(dt,gt). 
Since by definition union(gt',dt)=dt, computing\n#  iou(gt,dt,iscrowd) = iou(gt',dt) = area(intersect(gt,dt)) / area(dt)\n# For crowd gt regions we use this modified criteria above for the iou.\n#\n# To compile run \"python setup.py build_ext --inplace\"\n# Please do not contact us for help with compiling.\n#\n# Microsoft COCO Toolbox.      version 2.0\n# Data, paper, and tutorials available at:  http://mscoco.org/\n# Code written by Piotr Dollar and Tsung-Yi Lin, 2015.\n# Licensed under the Simplified BSD License [see coco/license.txt]\n\niou         = _mask.iou\nmerge       = _mask.merge\nfrPyObjects = _mask.frPyObjects\n\ndef encode(bimask):\n    if len(bimask.shape) == 3:\n        return _mask.encode(bimask)\n    elif len(bimask.shape) == 2:\n        h, w = bimask.shape\n        return _mask.encode(bimask.reshape((h, w, 1), order='F'))[0]\n\ndef decode(rleObjs):\n    if type(rleObjs) == list:\n        return _mask.decode(rleObjs)\n    else:\n        return _mask.decode([rleObjs])[:,:,0]\n\ndef area(rleObjs):\n    if type(rleObjs) == list:\n        return _mask.area(rleObjs)\n    else:\n        return _mask.area([rleObjs])[0]\n\ndef toBbox(rleObjs):\n    if type(rleObjs) == list:\n        return _mask.toBbox(rleObjs)\n    else:\n        return _mask.toBbox([rleObjs])[0]"
  },
  {
    "path": "external/ext/ytvos.py",
    "content": "__author__ = 'ychfan'\n# Interface for accessing the YouTubeVIS dataset.\n\n# The following API functions are defined:\n#  YTVOS      - YTVOS api class that loads YouTubeVIS annotation file and prepares data structures.\n#  getAnnIds  - Get ann ids that satisfy given filter conditions.\n#  getCatIds  - Get cat ids that satisfy given filter conditions.\n#  getVidIds  - Get vid ids that satisfy given filter conditions.\n#  loadAnns   - Load anns with the specified ids.\n#  loadCats   - Load cats with the specified ids.\n#  loadVids   - Load vids with the specified ids.\n#  annToRLE   - Convert segmentation in an annotation to RLE.\n#  annToMask  - Convert segmentation in an annotation to binary mask.\n#  loadRes    - Load algorithm results and create API for accessing them.\n\n# Microsoft COCO Toolbox.      version 2.0\n# Data, paper, and tutorials available at:  http://mscoco.org/\n# Code written by Piotr Dollar and Tsung-Yi Lin, 2014.\n# Licensed under the Simplified BSD License [see bsd.txt]\n\nimport json\nimport time\nimport matplotlib.pyplot as plt\nfrom matplotlib.collections import PatchCollection\nfrom matplotlib.patches import Polygon\nimport numpy as np\nimport copy\nimport itertools\nfrom . 
import mask as maskUtils\nimport os\nfrom collections import defaultdict\nimport sys\nPYTHON_VERSION = sys.version_info[0]\n\n\ndef _isArrayLike(obj):\n    return hasattr(obj, '__iter__') and hasattr(obj, '__len__')\n\n\nclass YTVOS:\n    def __init__(self, annotation_file=None):\n        \"\"\"\n        Constructor of YouTubeVIS helper class for reading annotations.\n        :param annotation_file (str): location of annotation file\n        :return:\n        \"\"\"\n        # load dataset\n        self.dataset, self.anns, self.cats, self.vids = dict(), dict(), dict(), dict()\n        self.vidToAnns, self.catToVids = defaultdict(list), defaultdict(list)\n        if annotation_file is not None:\n            print('loading annotations into memory...')\n            tic = time.time()\n            with open(annotation_file, 'r') as f:\n                dataset = json.load(f)\n            assert isinstance(dataset, dict), 'annotation file format {} not supported'.format(type(dataset))\n            print('Done (t={:0.2f}s)'.format(time.time() - tic))\n            self.dataset = dataset\n            self.createIndex()\n\n    def createIndex(self):\n        # create index\n        print('creating index...')\n        anns, cats, vids = {}, {}, {}\n        vidToAnns, catToVids = defaultdict(list), defaultdict(list)\n        if 'annotations' in self.dataset:\n            for ann in self.dataset['annotations']:\n                vidToAnns[ann['video_id']].append(ann)\n                anns[ann['id']] = ann\n\n        if 'videos' in self.dataset:\n            for vid in self.dataset['videos']:\n                vids[vid['id']] = vid\n\n        if 'categories' in self.dataset:\n            for cat in self.dataset['categories']:\n                cats[cat['id']] = cat\n\n        if 'annotations' in self.dataset and 'categories' in self.dataset:\n            for ann in self.dataset['annotations']:\n                
catToVids[ann['category_id']].append(ann['video_id'])\n\n        print('index created!')\n\n        # create class members\n        self.anns = anns\n        self.vidToAnns = vidToAnns\n        self.catToVids = catToVids\n        self.vids = vids\n        self.cats = cats\n\n    def info(self):\n        \"\"\"\n        Print information about the annotation file.\n        :return:\n        \"\"\"\n        for key, value in self.dataset['info'].items():\n            print('{}: {}'.format(key, value))\n\n    def getAnnIds(self, vidIds=[], catIds=[], areaRng=[], iscrowd=None):\n        \"\"\"\n        Get ann ids that satisfy given filter conditions. default skips that filter\n        :param vidIds  (int array)     : get anns for given vids\n               catIds  (int array)     : get anns for given cats\n               areaRng (float array)   : get anns for given area range (e.g. [0 inf])\n               iscrowd (boolean)       : get anns for given crowd label (False or True)\n        :return: ids (int array)       : integer array of ann ids\n        \"\"\"\n        vidIds = vidIds if _isArrayLike(vidIds) else [vidIds]\n        catIds = catIds if _isArrayLike(catIds) else [catIds]\n\n        if len(vidIds) == len(catIds) == len(areaRng) == 0:\n            anns = self.dataset['annotations']\n        else:\n            if not len(vidIds) == 0:\n                lists = [self.vidToAnns[vidId] for vidId in vidIds if vidId in self.vidToAnns]\n                anns = list(itertools.chain.from_iterable(lists))\n            else:\n                anns = self.dataset['annotations']\n            anns = anns if len(catIds)  == 0 else [ann for ann in anns if ann['category_id'] in catIds]\n            anns = anns if len(areaRng) == 0 else [ann for ann in anns if ann['avg_area'] > areaRng[0] and ann['avg_area'] < areaRng[1]]\n        if not iscrowd == None:\n            ids = [ann['id'] for ann in anns if ann['iscrowd'] == iscrowd]\n        else:\n            ids = [ann['id'] for 
ann in anns]\n        return ids\n\n    def getCatIds(self, catNms=[], supNms=[], catIds=[]):\n        \"\"\"\n        Get cat ids that satisfy given filter conditions. Default skips that filter.\n        :param catNms (str array)  : get cats for given cat names\n        :param supNms (str array)  : get cats for given supercategory names\n        :param catIds (int array)  : get cats for given cat ids\n        :return: ids (int array)   : integer array of cat ids\n        \"\"\"\n        catNms = catNms if _isArrayLike(catNms) else [catNms]\n        supNms = supNms if _isArrayLike(supNms) else [supNms]\n        catIds = catIds if _isArrayLike(catIds) else [catIds]\n\n        if len(catNms) == len(supNms) == len(catIds) == 0:\n            cats = self.dataset['categories']\n        else:\n            cats = self.dataset['categories']\n            cats = cats if len(catNms) == 0 else [cat for cat in cats if cat['name'] in catNms]\n            cats = cats if len(supNms) == 0 else [cat for cat in cats if cat['supercategory'] in supNms]\n            cats = cats if len(catIds) == 0 else [cat for cat in cats if cat['id'] in catIds]\n        ids = [cat['id'] for cat in cats]\n        return ids\n\n    def getVidIds(self, vidIds=[], catIds=[]):\n        '''\n        Get vid ids that satisfy given filter conditions.\n        :param vidIds (int array) : get vids for given ids\n        :param catIds (int array) : get vids with all given cats\n        :return: ids (int array)  : integer array of vid ids\n        '''\n        vidIds = vidIds if _isArrayLike(vidIds) else [vidIds]\n        catIds = catIds if _isArrayLike(catIds) else [catIds]\n\n        if len(vidIds) == len(catIds) == 0:\n            ids = self.vids.keys()\n        else:\n            ids = set(vidIds)\n            for i, catId in enumerate(catIds):\n                if i == 0 and len(ids) == 0:\n                    ids = set(self.catToVids[catId])\n                else:\n                    ids &= 
set(self.catToVids[catId])\n        return list(ids)\n\n    def loadAnns(self, ids=[]):\n        \"\"\"\n        Load anns with the specified ids.\n        :param ids (int array)       : integer ids specifying anns\n        :return: anns (object array) : loaded ann objects\n        \"\"\"\n        if _isArrayLike(ids):\n            return [self.anns[id] for id in ids]\n        elif type(ids) == int:\n            return [self.anns[ids]]\n\n    def loadCats(self, ids=[]):\n        \"\"\"\n        Load cats with the specified ids.\n        :param ids (int array)       : integer ids specifying cats\n        :return: cats (object array) : loaded cat objects\n        \"\"\"\n        if _isArrayLike(ids):\n            return [self.cats[id] for id in ids]\n        elif type(ids) == int:\n            return [self.cats[ids]]\n\n    def loadVids(self, ids=[]):\n        \"\"\"\n        Load vids with the specified ids.\n        :param ids (int array)       : integer ids specifying vids\n        :return: vids (object array) : loaded vid objects\n        \"\"\"\n        if _isArrayLike(ids):\n            return [self.vids[id] for id in ids]\n        elif type(ids) == int:\n            return [self.vids[ids]]\n\n    def loadRes(self, resFile):\n        \"\"\"\n        Load result file and return a result api object.\n        :param   resFile (str)     : file name of result file\n        :return: res (obj)         : result api object\n        \"\"\"\n        res = YTVOS()\n        res.dataset['videos'] = [img for img in self.dataset['videos']]\n\n        print('Loading and preparing results...')\n        tic = time.time()\n        if isinstance(resFile, str):\n            with open(resFile) as f:\n                anns = json.load(f)\n        elif type(resFile) == np.ndarray:\n            anns = self.loadNumpyAnnotations(resFile)\n        else:\n            anns = resFile\n        assert type(anns) == list, 'results is not an array of objects'\n        annsVidIds = [ann['video_id'] for 
ann in anns]\n        assert set(annsVidIds) == (set(annsVidIds) & set(self.getVidIds())), \\\n               'Results do not correspond to current coco set'\n        if 'segmentations' in anns[0]:\n            res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])\n            for id, ann in enumerate(anns):\n                ann['areas'] = []\n                if not 'bboxes' in ann:\n                    ann['bboxes'] = []\n                for seg in ann['segmentations']:\n                    # now only support compressed RLE format as segmentation results\n                    if seg:\n                        ann['areas'].append(maskUtils.area(seg))\n                        if len(ann['bboxes']) < len(ann['areas']):\n                            ann['bboxes'].append(maskUtils.toBbox(seg))\n                    else:\n                        ann['areas'].append(None)\n                        if len(ann['bboxes']) < len(ann['areas']):\n                            ann['bboxes'].append(None)\n                ann['id'] = id+1\n                l = [a for a in ann['areas'] if a]\n                if len(l)==0:\n                  ann['avg_area'] = 0\n                else:\n                  ann['avg_area'] = np.array(l).mean()\n                ann['iscrowd'] = 0\n        print('DONE (t={:0.2f}s)'.format(time.time()- tic))\n\n        res.dataset['annotations'] = anns\n        res.createIndex()\n        return res\n\n    def annToRLE(self, ann, frameId):\n        \"\"\"\n        Convert annotation which can be polygons, uncompressed RLE to RLE.\n        :return: binary mask (numpy 2D array)\n        \"\"\"\n        t = self.vids[ann['video_id']]\n        h, w = t['height'], t['width']\n        segm = ann['segmentations'][frameId]\n        if type(segm) == list:\n            # polygon -- a single object might consist of multiple parts\n            # we merge all parts into one mask rle code\n            rles = maskUtils.frPyObjects(segm, h, w)\n            rle = 
maskUtils.merge(rles)\n        elif type(segm['counts']) == list:\n            # uncompressed RLE\n            rle = maskUtils.frPyObjects(segm, h, w)\n        else:\n            # rle\n            rle = segm\n        return rle\n\n    def annToMask(self, ann, frameId):\n        \"\"\"\n        Convert annotation which can be polygons, uncompressed RLE, or RLE to binary mask.\n        :return: binary mask (numpy 2D array)\n        \"\"\"\n        rle = self.annToRLE(ann, frameId)\n        m = maskUtils.decode(rle)\n        return m"
  },
  {
    "path": "external/fcn_mask_head.py",
    "content": "import numpy as np\nimport torch\nfrom mmdet.models.builder import HEADS\nfrom mmdet.models.roi_heads.mask_heads.fcn_mask_head import (FCNMaskHead,\n                                                             _do_paste_mask)\n\nBYTES_PER_FLOAT = 4\n# TODO: This memory limit may be too much or too little. It would be better to\n# determine it based on available resources.\nGPU_MEM_LIMIT = 1024**3  # 1 GB memory limit\n\n\n@HEADS.register_module()\nclass InstanceMaskHead(FCNMaskHead):\n\n    def __init__(self, **kwargs):\n        super(InstanceMaskHead, self).__init__(**kwargs)\n\n    def get_seg_masks(self, mask_pred, det_bboxes, det_labels, rcnn_test_cfg,\n                      ori_shape, scale_factor, rescale):\n        \"\"\"Get segmentation masks from mask_pred and bboxes.\n\n        The only difference between InstanceMaskHead and FCNMaskHead is the\n        output format of instance masks. The original FCNMaskHead returns\n        numpy masks.\n\n        Args:\n            mask_pred (Tensor or ndarray): shape (n, #class, h, w).\n                For single-scale testing, mask_pred is the direct output of\n                model, whose type is Tensor, while for multi-scale testing,\n                it will be converted to numpy array outside of this method.\n            det_bboxes (Tensor): shape (n, 4/5)\n            det_labels (Tensor): shape (n, )\n            rcnn_test_cfg (dict): rcnn testing config\n            ori_shape (Tuple): original image height and width, shape (2,)\n            scale_factor (float | Tensor): If ``rescale is True``, box\n                coordinates are divided by this scale factor to fit\n                ``ori_shape``.\n            rescale (bool): If True, the resulting masks will be rescaled to\n                ``ori_shape``.\n\n        Returns:\n            list[list]: encoded masks. The c-th item in the outer list\n                corresponds to the c-th class. 
Given the c-th outer list, the\n                i-th item in that inner list is the mask for the i-th box with\n                class label c.\n\n        Example:\n            >>> import mmcv\n            >>> from mmdet.models.roi_heads.mask_heads.fcn_mask_head import *  # NOQA\n            >>> N = 7  # N = number of extracted ROIs\n            >>> C, H, W = 11, 32, 32\n            >>> # Create example instance of FCN Mask Head.\n            >>> self = FCNMaskHead(num_classes=C, num_convs=0)\n            >>> inputs = torch.rand(N, self.in_channels, H, W)\n            >>> mask_pred = self.forward(inputs)\n            >>> # Each input is associated with some bounding box\n            >>> det_bboxes = torch.Tensor([[1, 1, 42, 42 ]] * N)\n            >>> det_labels = torch.randint(0, C, size=(N,))\n            >>> rcnn_test_cfg = mmcv.Config({'mask_thr_binary': 0, })\n            >>> ori_shape = (H * 4, W * 4)\n            >>> scale_factor = torch.FloatTensor((1, 1))\n            >>> rescale = False\n            >>> # Encoded masks are a list for each category.\n            >>> encoded_masks = self.get_seg_masks(\n            >>>     mask_pred, det_bboxes, det_labels, rcnn_test_cfg, ori_shape,\n            >>>     scale_factor, rescale\n            >>> )\n            >>> assert len(encoded_masks) == C\n            >>> assert sum(list(map(len, encoded_masks))) == N\n        \"\"\"\n        if isinstance(mask_pred, torch.Tensor):\n            mask_pred = mask_pred.sigmoid()\n        else:\n            mask_pred = det_bboxes.new_tensor(mask_pred)\n\n        device = mask_pred.device\n        bboxes = det_bboxes[:, :4]\n        labels = det_labels\n\n        if rescale:\n            img_h, img_w = ori_shape[:2]\n        else:\n            if isinstance(scale_factor, float):\n                img_h = np.round(ori_shape[0] * scale_factor).astype(np.int32)\n                img_w = np.round(ori_shape[1] * scale_factor).astype(np.int32)\n            else:\n                
w_scale, h_scale = scale_factor[0], scale_factor[1]\n                img_h = np.round(ori_shape[0] * h_scale.item()).astype(\n                    np.int32)\n                img_w = np.round(ori_shape[1] * w_scale.item()).astype(\n                    np.int32)\n            scale_factor = 1.0\n\n        if not isinstance(scale_factor, (float, torch.Tensor)):\n            scale_factor = bboxes.new_tensor(scale_factor)\n        bboxes = bboxes / scale_factor\n\n        if torch.onnx.is_in_onnx_export():\n            # TODO: Remove after F.grid_sample is supported.\n            from torchvision.models.detection.roi_heads \\\n                import paste_masks_in_image\n            masks = paste_masks_in_image(mask_pred, bboxes, ori_shape[:2])\n            thr = rcnn_test_cfg.get('mask_thr_binary', 0)\n            if thr > 0:\n                masks = masks >= thr\n            return masks\n\n        N = len(mask_pred)\n        # The actual implementation split the input into chunks,\n        # and paste them chunk by chunk.\n        if device.type == 'cpu':\n            # CPU is most efficient when they are pasted one by one with\n            # skip_empty=True, so that it performs minimal number of\n            # operations.\n            num_chunks = N\n        else:\n            # GPU benefits from parallelism for larger chunks,\n            # but may have memory issue\n            num_chunks = int(\n                np.ceil(N * img_h * img_w * BYTES_PER_FLOAT / GPU_MEM_LIMIT))\n            assert (num_chunks <=\n                    N), 'Default GPU_MEM_LIMIT is too small; try increasing it'\n        chunks = torch.chunk(torch.arange(N, device=device), num_chunks)\n\n        threshold = rcnn_test_cfg.mask_thr_binary\n        im_mask = torch.zeros(\n            N,\n            img_h,\n            img_w,\n            device=device,\n            dtype=torch.bool if threshold >= 0 else torch.uint8)\n\n        if not self.class_agnostic:\n            mask_pred = 
mask_pred[range(N), labels][:, None]\n\n        for inds in chunks:\n            masks_chunk, spatial_inds = _do_paste_mask(\n                mask_pred[inds],\n                bboxes[inds],\n                img_h,\n                img_w,\n                skip_empty=device.type == 'cpu')\n\n            if threshold >= 0:\n                masks_chunk = (masks_chunk >= threshold).to(dtype=torch.bool)\n            else:\n                # for visualization and debugging\n                masks_chunk = (masks_chunk * 255).to(dtype=torch.uint8)\n\n            im_mask[(inds, ) + spatial_inds] = masks_chunk\n\n        return im_mask\n"
  },
  {
    "path": "external/kitti_step_dvps.py",
    "content": "import os\nimport random\nfrom typing import Dict, List\n\nimport copy\n\nimport mmcv\nimport numpy as np\nimport torch\n\nfrom mmdet.datasets.builder import DATASETS\nfrom mmdet.datasets.pipelines import Compose\nfrom mmdet.datasets import CustomDataset\nfrom mmdet.utils import get_root_logger\n\nfrom external.dataset.mIoU import eval_miou\n\n\nclass SeqObj:\n    # This divisor is orthogonal with panoptic class-instance divisor.\n    DIVISOR = 1000000\n\n    def __init__(self, the_dict: Dict):\n        self.dict = the_dict\n        assert 'seq_id' in self.dict and 'img_id' in self.dict\n\n    def __hash__(self):\n        return self.dict['seq_id'] * self.DIVISOR + self.dict['img_id']\n\n    def __eq__(self, other):\n        return self.dict['seq_id'] == other.dict['seq_id'] and self.dict['img_id'] == other.dict['img_id']\n\n    def __getitem__(self, attr):\n        return self.dict[attr]\n\n\n@DATASETS.register_module()\nclass KITTISTEPDVPSDataset:\n    CLASSES = ('road', 'sidewalk', 'building', 'wall', 'fence', 'pole',\n               'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',\n               'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',\n               'bicycle')\n\n    def __init__(self,\n                 pipeline=None,\n                 data_root=None,\n                 test_mode=False,\n                 split='train',\n                 ref_seq_index: List[int] = None,\n                 is_instance_only: bool = True,\n                 with_depth: bool = False\n                 ):\n        assert data_root is not None\n        data_root = os.path.expanduser(data_root)\n        video_seq_dir = os.path.join(data_root, 'video_sequence', split)\n        assert os.path.exists(video_seq_dir)\n        assert 'leftImg8bit' not in video_seq_dir\n\n        self.num_thing_classes = 2\n        self.num_stuff_classes = 17\n        self.thing_before_stuff = False\n\n        # ref_seq_index is None means no ref img\n    
    if ref_seq_index is None:\n            ref_seq_index = []\n\n        filenames = list(map(lambda x: str(x), os.listdir(video_seq_dir)))\n        img_names = sorted(list(filter(lambda x: 'leftImg8bit' in x, filenames)))\n\n        images = []\n        for item in img_names:\n            seq_id, img_id, _ = item.split(sep=\"_\", maxsplit=2)\n            if int(seq_id) == 1 and int(img_id) in [177, 178, 179, 180] and with_depth:\n                continue\n            item_full = os.path.join(video_seq_dir, item)\n            images.append(SeqObj({\n                'seq_id': int(seq_id),\n                'img_id': int(img_id),\n                'img': item_full,\n                'depth': item_full.replace('leftImg8bit', 'depth') if with_depth else None,\n                'ann': item_full.replace('leftImg8bit', 'panoptic'),\n                # This should be modified carefully for each dataset. Usually 255.\n                'no_obj_class': 255\n            }))\n            assert os.path.exists(images[-1]['img'])\n            assert images[-1]['depth'] is None or os.path.exists(images[-1]['depth']), \\\n                \"Missing depth : {}\".format(images[-1]['depth'])\n            # assert os.path.exists(images[-1]['ann'])\n\n        reference_images = {hash(image): image for image in images}\n        sequences = []\n        for img_cur in images:\n            is_seq = True\n            seq_now = [img_cur.dict]\n            if ref_seq_index:\n                for index in random.choices(ref_seq_index, k=1):\n                    query_obj = SeqObj({\n                        'seq_id': img_cur.dict['seq_id'],\n                        'img_id': img_cur.dict['img_id'] + index\n                    })\n                    if hash(query_obj) in reference_images:\n                        seq_now.append(reference_images[hash(query_obj)].dict)\n                    else:\n                        is_seq = False\n                        break\n            if is_seq:\n                
sequences.append(seq_now)\n\n        self.sequences = sequences\n        self.ref_seq_index = ref_seq_index\n\n        # mmdet\n        self.pipeline = Compose(pipeline)\n        self.test_mode = test_mode\n\n        # misc\n        self.flag = self._set_groups()\n        self.is_instance_only = is_instance_only\n\n        # For evaluation\n        self.max_ins = 10000\n        self.no_obj_id = 255\n\n    def pre_pipelines(self, results):\n        for _results in results:\n            _results['img_info'] = []\n            _results['thing_lower'] = 0 if self.thing_before_stuff else self.num_stuff_classes\n            _results['thing_upper'] = self.num_thing_classes \\\n                if self.thing_before_stuff else self.num_stuff_classes + self.num_thing_classes\n            _results['is_instance_only'] = self.is_instance_only\n            _results['ori_filename'] = os.path.basename(_results['img'])\n\n    def prepare_train_img(self, idx):\n        \"\"\"Get training data and annotations after pipeline.\n\n        Args:\n            idx (int): Index of data.\n\n        Returns:\n            dict: Training data and annotation after pipeline with new keys \\\n                introduced by pipeline.\n        \"\"\"\n        results = copy.deepcopy(self.sequences[idx])\n        self.pre_pipelines(results)\n        return self.pipeline(results)\n\n    def prepare_test_img(self, idx):\n        results = copy.deepcopy(self.sequences[idx])\n        self.pre_pipelines(results)\n        # At test time, single-image inference does not require a reference sequence\n        if not self.ref_seq_index:\n            results = results[0]\n        return self.pipeline(results)\n\n    def _rand_another(self, idx):\n        \"\"\"Get another random index from the same group as the given index.\"\"\"\n        pool = np.where(self.flag == self.flag[idx])[0]\n        return np.random.choice(pool)\n\n    # Copied and modified from mmdet\n    def __getitem__(self, idx):\n        \"\"\"Get training/test data 
after pipeline.\n\n        Args:\n            idx (int): Index of data.\n\n        Returns:\n            dict: Training/test data (with annotation if `test_mode` is set \\\n                True).\n        \"\"\"\n\n        if self.test_mode:\n            return self.prepare_test_img(idx)\n        else:\n            while True:\n                cur_data = self.prepare_train_img(idx)\n                if cur_data is None:\n                    idx = self._rand_another(idx)\n                    continue\n                return cur_data\n\n    def __len__(self):\n        \"\"\"Total number of samples of data.\"\"\"\n        return len(self.sequences)\n\n    def _set_groups(self):\n        return np.zeros((len(self)), dtype=np.int64)\n\n    # The evaluate func\n    def evaluate(\n            self,\n            results,\n            **kwargs\n    ):\n        # logger and metric\n        thing_knet2real = [11, 13]\n        pred_results_handled = []\n        pred_depth = []\n        pred_depth_final = []\n        item_id = 0\n        sem_preds = []\n        for item in results:\n            if item[-1] is not None:\n                # With depth\n                bbox_results, mask_results, seg_results, depth, depth_final = item\n                pred_depth.append(depth)\n                pred_depth_final.append(depth_final)\n            else:\n                bbox_results, mask_results, seg_results, _, _ = item\n            # in seg_info id starts from 1\n            inst_map, seg_info = seg_results\n            cat_map = np.zeros_like(inst_map) + self.num_thing_classes + self.num_stuff_classes\n            for instance in seg_info:\n                cat_cur = instance['category_id']\n                if instance['isthing']:\n                    cat_cur = thing_knet2real[cat_cur]\n                else:\n                    if self.thing_before_stuff:\n                        raise NotImplementedError\n                    else:\n                        # stuff starts from 1 in the 
model\n                        cat_cur -= 1\n                        offset = 0\n                        for thing_id in thing_knet2real:\n                            if cat_cur + offset >= thing_id:\n                                offset += 1\n                        cat_cur += offset\n                assert cat_cur < self.num_thing_classes + self.num_stuff_classes\n                cat_map[inst_map == instance['id']] = cat_cur\n                if not instance['isthing']:\n                    inst_map[inst_map == instance['id']] = 0\n            pred_results_handled.append(cat_map.astype(np.int32) * self.max_ins + inst_map.astype(np.int32))\n            item_id += 1\n            sem_preds.append(cat_map)\n\n        gt_panseg = []\n        gt_depth = []\n        sem_targets = []\n        for item in self.sequences:\n            # Only for single\n            item = item[0]\n            # Only for single\n            id_map = mmcv.imread(item['ann'], flag='color', channel_order='rgb')\n            gt_semantic_seg = id_map[..., 0].astype(np.int32)\n            sem_targets.append(gt_semantic_seg)\n            gt_inst_map = id_map[..., 1].astype(np.int32) * 256 + id_map[..., 2].astype(np.int32)\n            ps_id = gt_semantic_seg * self.max_ins + gt_inst_map\n            gt_panseg.append(ps_id)\n            if len(pred_depth) > 0:\n                gt_depth_cur = mmcv.imread(item['depth'], flag='unchanged').astype(np.float32) / 256.\n                gt_depth.append(gt_depth_cur)\n\n        vpq_results = []\n        for pred, gt in zip(pred_results_handled, gt_panseg):\n            vpq_result = vpq_eval([pred, gt])\n            vpq_results.append(vpq_result)\n\n        iou_per_class = np.stack([result[0] for result in vpq_results]).sum(axis=0)[\n                        :self.num_thing_classes + self.num_stuff_classes]\n        tp_per_class = np.stack([result[1] for result in vpq_results]).sum(axis=0)[\n                       :self.num_thing_classes + 
self.num_stuff_classes]\n        fn_per_class = np.stack([result[2] for result in vpq_results]).sum(axis=0)[\n                       :self.num_thing_classes + self.num_stuff_classes]\n        fp_per_class = np.stack([result[3] for result in vpq_results]).sum(axis=0)[\n                       :self.num_thing_classes + self.num_stuff_classes]\n\n        abs_rels = []\n        abs_rel_finals = []\n        if len(pred_depth) > 0:\n            for pred, pred_final, gt in zip(pred_depth, pred_depth_final, gt_depth):\n                depth_mask = gt > 0.\n                abs_rel_normal = np.mean(\n                    np.abs(\n                        pred[depth_mask] -\n                        gt[depth_mask]) /\n                    gt[depth_mask])\n                abs_rel_final = np.mean(\n                    np.abs(\n                        pred_final[depth_mask] -\n                        gt[depth_mask]) /\n                    gt[depth_mask])\n                abs_rels.append(abs_rel_normal)\n                abs_rel_finals.append(abs_rel_final)\n            abs_rel = np.stack(abs_rels).mean(axis=0)\n            abs_rel_final = np.stack(abs_rel_finals).mean(axis=0)\n        else:\n            abs_rel = 0.\n            abs_rel_final = 0.\n\n        # calculate the PQs\n        epsilon = 0.\n        sq = iou_per_class / (tp_per_class + epsilon)\n        rq = tp_per_class / (tp_per_class + 0.5 *\n                             fn_per_class + 0.5 * fp_per_class + epsilon)\n        pq = sq * rq\n        things_index = np.zeros((19,)).astype(bool)\n        things_index[11] = True\n        things_index[13] = True\n        stuff_pq = pq[np.logical_not(things_index)]\n        things_pq = pq[things_index]\n\n        miou_per_class = eval_miou(sem_preds, sem_targets, num_classes=self.num_thing_classes + self.num_stuff_classes)\n        print(\"class        pq\\t\\tsq\\t\\trq\\t\\ttp\\t\\tfp\\t\\tfn\\t\\tmIoU\")\n\n        for i in range(len(self.CLASSES)):\n            
print(\"{}{}{:.3f}\\t\\t{:.3f}\\t\\t{:.3f}\\t\\t{:.0f}\\t\\t{:.0f}\\t\\t{:.0f}\\t\\t{:.3f}\".format(\n                self.CLASSES[i], ' ' * (13 - len(self.CLASSES[i])), pq[i], sq[i], rq[i], tp_per_class[i],\n                fp_per_class[i], fn_per_class[i], miou_per_class[i]\n            ))\n\n        return {\n            \"abs_rel\": abs_rel,\n            \"abs_rel_final\": abs_rel_final,\n            \"PQ\": np.nan_to_num(pq).mean() * 100,\n            \"Stuff PQ\": np.nan_to_num(stuff_pq).mean() * 100,\n            \"Things PQ\": np.nan_to_num(things_pq).mean() * 100,\n            \"mIoU\": np.nan_to_num(miou_per_class).mean() * 100,\n        }\n\n\ndef vpq_eval(element):\n    import six\n    pred_ids, gt_ids = element\n    max_ins = 10000\n    ign_id = 255\n    offset = 2 ** 30\n    num_cat = 19 + 1\n\n    iou_per_class = np.zeros(num_cat, dtype=np.float64)\n    tp_per_class = np.zeros(num_cat, dtype=np.float64)\n    fn_per_class = np.zeros(num_cat, dtype=np.float64)\n    fp_per_class = np.zeros(num_cat, dtype=np.float64)\n\n    def _ids_to_counts(id_array):\n        ids, counts = np.unique(id_array, return_counts=True)\n        return dict(six.moves.zip(ids, counts))\n\n    pred_areas = _ids_to_counts(pred_ids)\n    gt_areas = _ids_to_counts(gt_ids)\n\n    void_id = ign_id * max_ins\n    ign_ids = {\n        gt_id for gt_id in six.iterkeys(gt_areas)\n        if (gt_id // max_ins) == ign_id\n    }\n\n    int_ids = gt_ids.astype(np.int64) * offset + pred_ids.astype(np.int64)\n    int_areas = _ids_to_counts(int_ids)\n\n    def prediction_void_overlap(pred_id):\n        void_int_id = void_id * offset + pred_id\n        return int_areas.get(void_int_id, 0)\n\n    def prediction_ignored_overlap(pred_id):\n        total_ignored_overlap = 0\n        for _ign_id in ign_ids:\n            int_id = _ign_id * offset + pred_id\n            total_ignored_overlap += int_areas.get(int_id, 0)\n        return total_ignored_overlap\n\n    gt_matched = set()\n    pred_matched = 
set()\n\n    for int_id, int_area in six.iteritems(int_areas):\n        gt_id = int(int_id // offset)\n        gt_cat = int(gt_id // max_ins)\n        pred_id = int(int_id % offset)\n        pred_cat = int(pred_id // max_ins)\n        if gt_cat != pred_cat:\n            continue\n        union = (\n                gt_areas[gt_id] + pred_areas[pred_id] - int_area -\n                prediction_void_overlap(pred_id)\n        )\n        iou = int_area / union\n        if iou > 0.5:\n            tp_per_class[gt_cat] += 1\n            iou_per_class[gt_cat] += iou\n            gt_matched.add(gt_id)\n            pred_matched.add(pred_id)\n\n    for gt_id in six.iterkeys(gt_areas):\n        if gt_id in gt_matched:\n            continue\n        cat_id = gt_id // max_ins\n        if cat_id == ign_id:\n            continue\n        fn_per_class[cat_id] += 1\n\n    for pred_id in six.iterkeys(pred_areas):\n        if pred_id in pred_matched:\n            continue\n        if (prediction_ignored_overlap(pred_id) / pred_areas[pred_id]) > 0.5:\n            continue\n        cat = pred_id // max_ins\n        fp_per_class[cat] += 1\n\n    return iou_per_class, tp_per_class, fn_per_class, fp_per_class\n\n\nif __name__ == '__main__':\n    import dataset.dvps_pipelines.loading\n    import dataset.dvps_pipelines.transforms\n    import dataset.pipelines.transforms\n    import dataset.pipelines.formatting\n\n    img_norm_cfg = dict(\n        mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=False)\n\n    test_pipeline = [\n        dict(type='LoadMultiImagesDirect'),\n        dict(type='SeqPadWithDepth', size_divisor=32),\n        dict(type='SeqNormalize', **img_norm_cfg),\n        dict(\n            type='VideoCollect',\n            keys=['img']),\n        dict(type='ConcatVideoReferences'),\n        dict(type='SeqDefaultFormatBundle', ref_prefix='ref'),\n    ]\n\n    data = KITTISTEPDVPSDataset(\n        pipeline=[\n            dict(type='LoadMultiImagesDirect'),\n     
       dict(type='LoadMultiAnnotationsDirect', with_depth=True, divisor=-1),\n            dict(type='SeqFlipWithDepth', flip_ratio=0.5),\n            dict(type='SeqPadWithDepth', size_divisor=32),\n            dict(type='SeqNormalize', **img_norm_cfg),\n            dict(\n                type='VideoCollect',\n                keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks', 'gt_semantic_seg', 'gt_depth']),\n            dict(type='ConcatVideoReferences'),\n            dict(type='SeqDefaultFormatBundle', ref_prefix='ref'),\n        ],\n        data_root=os.path.expanduser('~/datasets/kitti-step'),\n        split='val',\n        ref_seq_index=[-1, 1],\n        with_depth=True,\n    )\n    np.set_string_function(lambda x: '<{} ; {}>'.format(x.shape, x.dtype))\n    torch.set_printoptions(profile='short')\n    for item in data:\n        print(item)\n"
  },
  {
    "path": "external/panoptic_fpn.py",
    "content": "from mmdet.models.builder import DETECTORS\nfrom mmdet.models.detectors.two_stage import TwoStageDetector\n\n\n@DETECTORS.register_module()\nclass PanopticFPN(TwoStageDetector):\n    \"\"\"Implementation of `Panoptic FPN <https://arxiv.org/abs/1901.02446>`_\"\"\"\n\n    def __init__(self,\n                 backbone,\n                 rpn_head,\n                 roi_head,\n                 train_cfg,\n                 test_cfg,\n                 neck=None,\n                 pretrained=None):\n        super(PanopticFPN, self).__init__(\n            backbone=backbone,\n            neck=neck,\n            rpn_head=rpn_head,\n            roi_head=roi_head,\n            train_cfg=train_cfg,\n            test_cfg=test_cfg,\n            pretrained=pretrained)\n\n    @property\n    def with_semantic(self):\n        \"\"\"bool: whether the detector has a semantic head\"\"\"\n        return ((hasattr(self, 'roi_head') and self.roi_head.with_semantic)\n                or (hasattr(self, 'semantic_head')\n                    and self.semantic_head is not None))\n"
  },
  {
    "path": "external/panoptic_head.py",
    "content": "import torch\nfrom mmdet.core import bbox2result\nfrom mmdet.models.builder import HEADS, build_head\nfrom mmdet.models.roi_heads import StandardRoIHead\n\n\nclass PanopticTestMixin(object):\n\n    def simple_test_semantic(self, x, img_metas):\n        segm_feature_pred = self.semantic_head(x)\n        semantic_seg_results = []\n        for i, img_meta in enumerate(img_metas):\n            semantic_seg_results.append(\n                self.semantic_head.get_semantic_seg(segm_feature_pred[i:i + 1],\n                                                    img_meta['ori_shape'],\n                                                    img_meta['img_shape'])[0])\n\n        return semantic_seg_results\n\n    def generate_panoptic(self, det_bboxes, det_labels, mask_preds, sem_seg,\n                          img_metas, merge_cfg):\n        panoptic_results = []\n        for i in range(len(img_metas)):\n            panoptic_results.append(\n                merge_stuff_thing(det_bboxes[i], det_labels[i], mask_preds[i],\n                                  sem_seg[i], merge_cfg))\n        return panoptic_results\n\n\n@HEADS.register_module()\nclass PanopticHead(StandardRoIHead, PanopticTestMixin):\n    \"\"\"Panoptic Segmentation Head for Panoptic Seg.\"\"\"\n\n    def __init__(self, *args, semantic_head, **kwargs):\n        super(PanopticHead, self).__init__(*args, **kwargs)\n        self.semantic_head = build_head(semantic_head)\n\n    @property\n    def with_semantic(self):\n        \"\"\"bool: whether the head has semantic head\"\"\"\n        if hasattr(self, 'semantic_head') and self.semantic_head is not None:\n            return True\n        else:\n            return False\n\n    def init_weights(self, pretrained):\n        \"\"\"Initialize the weights in head.\n\n        Args:\n            pretrained (str, optional): Path to pre-trained weights.\n                Defaults to None.\n        \"\"\"\n        super().init_weights(pretrained)\n        if 
self.with_semantic:\n            self.semantic_head.init_weights()\n\n    def forward_train(self,\n                      x,\n                      img_metas,\n                      proposal_list,\n                      gt_bboxes,\n                      gt_labels,\n                      gt_bboxes_ignore=None,\n                      gt_masks=None,\n                      gt_semantic_seg=None):\n        \"\"\"\n        Args:\n            x (list[Tensor]): list of multi-level img features.\n            img_metas (list[dict]): list of image info dict where each dict\n                has: 'img_shape', 'scale_factor', 'flip', and may also contain\n                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.\n                For details on the values of these keys see\n                `mmdet/datasets/pipelines/formatting.py:Collect`.\n            proposal_list (list[Tensor]): list of region proposals.\n            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with\n                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.\n            gt_labels (list[Tensor]): class indices corresponding to each box\n            gt_bboxes_ignore (None | list[Tensor]): specify which bounding\n                boxes can be ignored when computing the loss.\n            gt_masks (None | Tensor): true segmentation masks for each box\n                used if the architecture supports a segmentation task.\n            gt_semantic_seg (None | Tensor): semantic segmentation ground\n                truth of shape (N, 1, H, W), used if the head has a semantic\n                branch. Pixels in the padded region are set to\n                `semantic_head.ignore_label` before the loss is computed.\n\n        Returns:\n            dict[str, Tensor]: a dictionary of loss components\n        \"\"\"\n        # assign gts and sample proposals\n        if self.with_bbox or self.with_mask:\n            num_imgs = len(img_metas)\n            if gt_bboxes_ignore is None:\n                gt_bboxes_ignore = [None for _ in range(num_imgs)]\n            sampling_results = []\n            for i in range(num_imgs):\n                assign_result = self.bbox_assigner.assign(\n                    proposal_list[i], gt_bboxes[i], gt_bboxes_ignore[i],\n                   
 gt_labels[i])\n                sampling_result = self.bbox_sampler.sample(\n                    assign_result,\n                    proposal_list[i],\n                    gt_bboxes[i],\n                    gt_labels[i],\n                    feats=[lvl_feat[i][None] for lvl_feat in x])\n                sampling_results.append(sampling_result)\n\n        losses = dict()\n        # bbox head forward and loss\n        if self.with_bbox:\n            bbox_results = self._bbox_forward_train(x, sampling_results,\n                                                    gt_bboxes, gt_labels,\n                                                    img_metas)\n            losses.update(bbox_results['loss_bbox'])\n\n        # mask head forward and loss\n        if self.with_mask:\n            mask_results = self._mask_forward_train(x, sampling_results,\n                                                    bbox_results['bbox_feats'],\n                                                    gt_masks, img_metas)\n            losses.update(mask_results['loss_mask'])\n\n        if self.with_semantic:\n            for i in range(gt_semantic_seg.shape[0]):\n                gt_semantic_seg[i, :, img_metas[i]['img_shape']\n                                [0]:, :] = self.semantic_head.ignore_label\n                gt_semantic_seg[i, :, :, img_metas[i]['img_shape']\n                                [1]:] = self.semantic_head.ignore_label\n            seg_preds = self.semantic_head(x)\n            seg_losses = self.semantic_head.loss(seg_preds, gt_semantic_seg)\n            losses.update(seg_losses)\n\n        return losses\n\n    async def async_simple_test(self,\n                                x,\n                                proposal_list,\n                                img_metas,\n                                proposals=None,\n                                rescale=False):\n        \"\"\"Async test without augmentation.\"\"\"\n        raise NotImplementedError('PanopticHead does not 
support async test')\n\n    def simple_test(self,\n                    x,\n                    proposal_list,\n                    img_metas,\n                    proposals=None,\n                    rescale=False):\n        \"\"\"Test without augmentation.\"\"\"\n        assert self.with_bbox, 'Bbox head must be implemented.'\n\n        det_bboxes, det_labels = self.simple_test_bboxes(\n            x, img_metas, proposal_list, self.test_cfg, rescale=rescale)\n        if torch.onnx.is_in_onnx_export():\n            if self.with_mask:\n                segm_results = self.simple_test_mask(\n                    x, img_metas, det_bboxes, det_labels, rescale=rescale)\n                return det_bboxes, det_labels, segm_results\n            else:\n                return det_bboxes, det_labels\n\n        bbox_results = [\n            bbox2result(det_bboxes[i], det_labels[i],\n                        self.bbox_head.num_classes)\n            for i in range(len(det_bboxes))\n        ]\n\n        if not self.with_mask:\n            return bbox_results\n        else:\n            mask_preds = self.simple_test_mask(\n                x, img_metas, det_bboxes, det_labels, rescale=rescale)\n            segm_results = mask2result(mask_preds, det_labels,\n                                       self.mask_head.num_classes)\n\n            if self.with_semantic:\n                sem_seg = self.simple_test_semantic(x, img_metas)\n                panoptic_results = self.generate_panoptic(\n                    det_bboxes, det_labels, mask_preds, sem_seg, img_metas,\n                    self.test_cfg.merge_stuff_thing)\n                return list(zip(bbox_results, segm_results, panoptic_results))\n            return list(zip(bbox_results, segm_results))\n\n\ndef mask2result(mask_preds, labels, num_classes):\n    cls_segms = []\n    for batch_id, mask_pred in enumerate(mask_preds):\n        if isinstance(mask_pred, list):\n            cls_segms.append(mask_pred)\n            continue\n      
  cls_segms.append([[] for _ in range(num_classes)])\n        N = mask_preds[batch_id].shape[0]\n        for i in range(N):\n            cls_segms[batch_id][labels[batch_id][i]].append(\n                mask_pred[i].detach().cpu().numpy())\n    return cls_segms\n\n\ndef merge_stuff_thing(det_bboxes,\n                      det_labels,\n                      mask_preds,\n                      sem_seg,\n                      merge_cfg=None):\n    \"\"\"Merge stuff and thing segmentation maps.\n\n    This function is modified from https://github.com/facebookresearch/detectron2/blob/master/detectron2/modeling/meta_arch/panoptic_fpn.py#L183  # noqa\n\n    Args:\n        det_bboxes  (torch.Tensor): Bounding boxes in shape (n, 5).\n        det_labels (torch.Tensor): Labels of bounding boxes in shape (n, ).\n        mask_preds (torch.Tensor): Mask prediction in the original image size.\n        sem_seg (torch.Tensor): Semantic segmentation prediction in the original\n            image size.\n        merge_cfg (dict): The config dict containing merge hyper-parameters.\n    \"\"\"\n    sem_seg = sem_seg.argmax(dim=0)\n    box_scores = det_bboxes[:, -1]\n    panoptic_seg = torch.zeros_like(sem_seg, dtype=torch.int32)\n\n    # sort instance outputs by scores\n    sorted_inds = torch.argsort(-box_scores)\n\n    current_segment_id = 0\n    segments_info = []\n\n    if isinstance(mask_preds, list):\n        instance_masks = None\n    else:\n        instance_masks = mask_preds.to(\n            dtype=torch.bool, device=panoptic_seg.device)\n\n    # Add instances one-by-one, check for overlaps with existing ones\n    for inst_id in sorted_inds:\n        score = box_scores[inst_id].item()\n        if score < merge_cfg.instance_score_thr:\n            break\n        mask = instance_masks[inst_id]  # H,W\n        mask_area = mask.sum().item()\n\n        if mask_area == 0:\n            continue\n\n        intersect = (mask > 0) & (panoptic_seg > 0)\n        intersect_area = 
intersect.sum().item()\n\n        if intersect_area * 1.0 / mask_area > merge_cfg.iou_thr:\n            continue\n\n        if intersect_area > 0:\n            mask = mask & (panoptic_seg == 0)\n\n        current_segment_id += 1\n        panoptic_seg[mask] = current_segment_id\n        segments_info.append({\n            'id': current_segment_id,\n            'isthing': True,\n            'score': score,\n            'category_id': det_labels[inst_id].item(),\n            'instance_id': inst_id.item(),\n        })\n\n    # Add semantic results to remaining empty areas\n    semantic_labels = torch.unique(sem_seg).cpu().tolist()\n    for semantic_label in semantic_labels:\n        if semantic_label == 0:  # 0 is a special \"thing\" class\n            continue\n        mask = (sem_seg == semantic_label) & (panoptic_seg == 0)\n        mask_area = mask.sum().item()\n        if mask_area < merge_cfg.stuff_max_area:\n            continue\n\n        current_segment_id += 1\n        panoptic_seg[mask] = current_segment_id\n        segments_info.append({\n            'id': current_segment_id,\n            'isthing': False,\n            'category_id': semantic_label,\n            'area': mask_area,\n        })\n\n    return panoptic_seg.cpu().numpy(), segments_info\n"
  },
  {
    "path": "external/semantic_seg_head.py",
    "content": "import torch.nn as nn\nimport torch.nn.functional as F\nfrom mmcv.cnn import kaiming_init\nfrom mmcv.runner import auto_fp16, force_fp32\nfrom mmdet.models.builder import HEADS, build_loss, build_neck\nfrom mmdet.models.roi_heads.mask_heads import FusedSemanticHead\n\n\n@HEADS.register_module()\nclass SemanticHead(FusedSemanticHead):\n    \"\"\"Semantic segmentation head that can be used in panoptic segmentation.\n\n    Args:\n        semantic_decoder (dict): Config dict of decoder.\n            It usually is a neck, like semantic FPN.\n        in_channels (int, optional): Input channels. Defaults to 256.\n        num_classes (int, optional):  Number of semantic classes including\n            the background. Defaults to 183.\n        ignore_label (int, optional): Labels to be ignored. Defaults to 255.\n        loss_seg (dict, optional): Config dict of loss.\n            Defaults to `dict(type='CrossEntropyLoss', use_sigmoid=False, \\\n            loss_weight=1.0)`.\n        conv_cfg (dict, optional): Config of convolutional layers.\n            Defaults to None.\n        norm_cfg (dict, optional): Config of normalization layers.\n            Defaults to None.\n    \"\"\"\n\n    def __init__(self,\n                 semantic_decoder,\n                 in_channels=256,\n                 num_classes=183,\n                 ignore_label=255,\n                 pred_stride=4,\n                 loss_seg=dict(\n                     type='CrossEntropyLoss',\n                     use_sigmoid=False,\n                     loss_weight=1.0),\n                 conv_cfg=None,\n                 norm_cfg=None):\n        super(FusedSemanticHead, self).__init__()\n        self.semantic_decoder = build_neck(semantic_decoder)\n        self.conv_logits = nn.Conv2d(in_channels, num_classes, 1)\n        self.loss_seg = build_loss(loss_seg)\n\n        self.in_channels = in_channels\n        self.num_classes = num_classes\n        self.ignore_label = ignore_label\n        
self.pred_stride = pred_stride\n        self.conv_cfg = conv_cfg\n        self.norm_cfg = norm_cfg\n        self.fp16_enabled = False\n\n    def init_weights(self):\n        kaiming_init(self.conv_logits)\n\n    @auto_fp16()\n    def forward(self, feats):\n        x = self.semantic_decoder(feats)\n        mask_pred = self.conv_logits(x)\n        return mask_pred\n\n    @force_fp32(apply_to=('mask_pred', ))\n    def loss(self, mask_pred, labels):\n        mask_pred = F.interpolate(\n            mask_pred,\n            scale_factor=self.pred_stride,\n            mode='bilinear',\n            align_corners=False)\n        labels = labels.squeeze(1).long()\n        loss_sem_seg = self.loss_seg.loss_weight * F.cross_entropy(\n            mask_pred,\n            labels,\n            reduction='mean',\n            ignore_index=self.ignore_label)\n        # loss_semantic_seg = self.loss_seg(\n        #     mask_pred, labels, ignore_index=self.ignore_label)\n        return dict(loss_sem_seg=loss_sem_seg)\n\n    def get_semantic_seg(self, seg_preds, ori_shape, img_shape_withoutpad):\n        \"\"\"Obtain semantic segmentation map for panoptic segmentation.\n\n        Args:\n            seg_preds (torch.Tensor): Segmentation logits, expected at\n                1 / pred_stride of the padded input resolution.\n            ori_shape (tuple[int]): Original image shape before resizing;\n                only (H, W) is used.\n            img_shape_withoutpad (tuple[int]): Resized image shape before\n                padding, used to crop away the padded region.\n\n        Returns:\n            torch.Tensor: Segmentation logits of shape\n                (N, num_classes, H, W), resized to ori_shape[:2].\n        \"\"\"\n        # Only a batch size of 1 (a single image) is supported\n        seg_preds = F.interpolate(\n            seg_preds,\n            scale_factor=self.pred_stride,\n            mode='bilinear',\n            align_corners=False)\n        seg_preds = seg_preds[:, :, 0:img_shape_withoutpad[0],\n                              0:img_shape_withoutpad[1]]\n      
  # seg_masks = F.softmax(seg_preds, 1)\n        # seg_masks = F.interpolate(\n        #     seg_masks,\n        #     size=ori_shape[0:2],\n        #     mode='bilinear',\n        #     align_corners=False)\n        seg_results = F.interpolate(\n            seg_preds,\n            size=ori_shape[0:2],\n            mode='bilinear',\n            align_corners=False)\n        return seg_results\n"
  },
  {
    "path": "external/semkitti_dvps.py",
    "content": "import os\nfrom typing import Dict, List\n\nimport copy\n\nimport mmcv\nimport numpy as np\nimport random\nimport torch\n\nfrom mmdet.datasets.builder import DATASETS\nfrom mmdet.datasets.pipelines import Compose\n\n\nclass SeqObj:\n    # This divisor is orthogonal with panoptic class-instance divisor.\n    DIVISOR = 1000000\n\n    def __init__(self, the_dict: Dict):\n        self.dict = the_dict\n        assert 'seq_id' in self.dict and 'img_id' in self.dict\n\n    def __hash__(self):\n        return self.dict['seq_id'] * self.DIVISOR + self.dict['img_id']\n\n    def __eq__(self, other):\n        return self.dict['seq_id'] == other.dict['seq_id'] and self.dict['img_id'] == other.dict['img_id']\n\n    def __getitem__(self, attr):\n        return self.dict[attr]\n\n\n@DATASETS.register_module()\nclass KITTIDVPSDataset:\n    CLASSES = (\n        'car', 'bicycle', 'motorcycle', 'truck', 'other-vehicle', 'person', 'bicyclist', 'motorcyclist'\n    )\n\n    def __init__(self,\n                 pipeline=None,\n                 data_root=None,\n                 test_mode=False,\n                 split='train',\n                 ref_seq_index: List[int] = None,\n                 is_instance_only: bool = True,\n                 ):\n        assert data_root is not None\n        data_root = os.path.expanduser(data_root)\n        video_seq_dir = os.path.join(data_root, 'video_sequence', split)\n        assert os.path.exists(video_seq_dir)\n        assert 'leftImg8bit' not in video_seq_dir\n\n        self.num_thing_classes = 8\n        self.num_stuff_classes = 11\n        self.thing_before_stuff = True\n\n        # ref_seq_index is None means no ref img\n        if ref_seq_index is None:\n            ref_seq_index = []\n\n        filenames = list(map(lambda x: str(x), os.listdir(video_seq_dir)))\n        depth_names = sorted(list(filter(lambda x: 'depth' in x, filenames)))\n        # No depth annotation\n        if not depth_names:\n            depth_names = 
sorted(list(filter(lambda x: 'leftImg8bit' in x, filenames)))\n\n        images = []\n        for item in depth_names:\n            seq_id, img_id, _ = item.split(sep=\"_\", maxsplit=2)\n            item_full = os.path.join(video_seq_dir, item)\n            images.append(SeqObj({\n                'seq_id': int(seq_id),\n                'img_id': int(img_id),\n                'img': os.path.join(video_seq_dir, \"{}_{}_{}.png\".format(seq_id, img_id, 'leftImg8bit')),\n                'depth': item_full,\n                'ann_class': os.path.join(video_seq_dir, \"{}_{}_{}.png\".format(seq_id, img_id, 'gtFine_class')),\n                'ann_inst': os.path.join(video_seq_dir, \"{}_{}_{}.png\".format(seq_id, img_id, 'gtFine_instance')),\n                # This should be modified carefully for each dataset. Usually 255.\n                'no_obj_class': 255\n            }))\n            assert os.path.exists(images[-1]['img'])\n            if not test_mode:\n                assert os.path.exists(images[-1]['depth'])\n                assert os.path.exists(images[-1]['ann_class'])\n                assert os.path.exists(images[-1]['ann_inst'])\n\n        reference_images = {hash(image): image for image in images}\n        sequences = []\n        for img_cur in images:\n            is_seq = True\n            seq_now = [img_cur.dict]\n            if ref_seq_index:\n                for index in random.choices(ref_seq_index, k=1):\n                    query_obj = SeqObj({\n                        'seq_id': img_cur.dict['seq_id'],\n                        'img_id': img_cur.dict['img_id'] + index\n                    })\n                    if hash(query_obj) in reference_images:\n                        seq_now.append(reference_images[hash(query_obj)].dict)\n                    else:\n                        is_seq = False\n                        break\n            if is_seq:\n                sequences.append(seq_now)\n\n        self.sequences = sequences\n        
self.ref_seq_index = ref_seq_index\n\n        # mmdet\n        self.pipeline = Compose(pipeline)\n        self.test_mode = test_mode\n\n        # misc\n        self.flag = self._set_groups()\n        self.is_instance_only = is_instance_only\n\n        # For evaluation\n        self.max_ins = 1000\n        self.no_obj_id = 255\n\n    def pre_pipelines(self, results):\n        for _results in results:\n            _results['img_info'] = []\n            _results['thing_lower'] = 0 if self.thing_before_stuff else self.num_stuff_classes\n            _results['thing_upper'] = self.num_thing_classes \\\n                if self.thing_before_stuff else self.num_stuff_classes + self.num_thing_classes\n            _results['is_instance_only'] = self.is_instance_only\n            _results['ori_filename'] = os.path.basename(_results['img'])\n\n    def prepare_train_img(self, idx):\n        \"\"\"Get training data and annotations after pipeline.\n\n        Args:\n            idx (int): Index of data.\n\n        Returns:\n            dict: Training data and annotation after pipeline with new keys \\\n                introduced by pipeline.\n        \"\"\"\n        results = copy.deepcopy(self.sequences[idx])\n        self.pre_pipelines(results)\n        return self.pipeline(results)\n\n    def prepare_test_img(self, idx):\n        results = copy.deepcopy(self.sequences[idx])\n        self.pre_pipelines(results)\n        # At test time, single-image inference does not require a reference sequence\n        if not self.ref_seq_index:\n            results = results[0]\n        return self.pipeline(results)\n\n    def _rand_another(self, idx):\n        \"\"\"Get another random index from the same group as the given index.\"\"\"\n        pool = np.where(self.flag == self.flag[idx])[0]\n        return np.random.choice(pool)\n\n    # Copied and modified from mmdet\n    def __getitem__(self, idx):\n        \"\"\"Get training/test data after pipeline.\n\n        Args:\n            idx (int): Index of 
data.\n\n        Returns:\n            dict: Training/test data (with annotation if `test_mode` is set \\\n                True).\n        \"\"\"\n\n        if self.test_mode:\n            return self.prepare_test_img(idx)\n        else:\n            while True:\n                cur_data = self.prepare_train_img(idx)\n                if cur_data is None:\n                    idx = self._rand_another(idx)\n                    continue\n                return cur_data\n\n    def __len__(self):\n        \"\"\"Total number of samples of data.\"\"\"\n        return len(self.sequences)\n\n    def _set_groups(self):\n        return np.zeros((len(self)), dtype=np.int64)\n\n    # The evaluate func\n    def evaluate(\n            self,\n            results,\n            **kwargs\n    ):\n        thing_lower = 0 if self.thing_before_stuff else self.num_stuff_classes\n        thing_upper = self.num_thing_classes \\\n            if self.thing_before_stuff else self.num_stuff_classes + self.num_thing_classes\n        pred_results_handled = []\n        pred_depth = []\n        pred_depth_final = []\n        for item in results:\n            bbox_results, mask_results, seg_results, depth, depth_final = item\n            pred_depth.append(depth)\n            pred_depth_final.append(depth_final)\n            # in seg_info id starts from 1\n            inst_map, seg_info = seg_results\n            cat_map = np.zeros_like(inst_map) + self.num_thing_classes + self.num_stuff_classes\n            for instance in seg_info:\n                cat_cur = instance['category_id']\n                if instance['isthing']:\n                    cat_cur += thing_lower\n                else:\n                    if self.thing_before_stuff:\n                        cat_cur = cat_cur - 1 + thing_upper\n                    else:\n                        # stuff starts from 1 in the model\n                        cat_cur -= 1\n                assert cat_cur < self.num_thing_classes + 
self.num_stuff_classes\n                cat_map[inst_map == instance['id']] = cat_cur\n                if not instance['isthing']:\n                    inst_map[inst_map == instance['id']] = 0\n            pred_results_handled.append(cat_map.astype(np.int32) * 10000 + inst_map.astype(np.int32))\n\n        gt_panseg = []\n        gt_depth = []\n        for item in self.sequences:\n            # Only the first (key) frame of each sequence is evaluated\n            item = item[0]\n            cat_id = mmcv.imread(item['ann_class'], flag='unchanged').astype(np.int32)\n            inst_id = mmcv.imread(item['ann_inst'], flag='unchanged').astype(np.int32)\n            ps_id = cat_id * 10000 + inst_id\n            gt_panseg.append(ps_id)\n            gt_depth_cur = mmcv.imread(item['depth'], flag='unchanged').astype(np.float32) / 256.\n            gt_depth.append(gt_depth_cur)\n\n        vpq_results = []\n        for pred, gt in zip(pred_results_handled, gt_panseg):\n            vpq_result = vpq_eval([pred, gt])\n            vpq_results.append(vpq_result)\n\n        iou_per_class = np.stack([result[0] for result in vpq_results]).sum(axis=0)[\n                        :self.num_thing_classes + self.num_stuff_classes]\n        tp_per_class = np.stack([result[1] for result in vpq_results]).sum(axis=0)[\n                       :self.num_thing_classes + self.num_stuff_classes]\n        fn_per_class = np.stack([result[2] for result in vpq_results]).sum(axis=0)[\n                       :self.num_thing_classes + self.num_stuff_classes]\n        fp_per_class = np.stack([result[3] for result in vpq_results]).sum(axis=0)[\n                       :self.num_thing_classes + self.num_stuff_classes]\n\n        abs_rels = []\n        abs_rel_finals = []\n        for pred, pred_final, gt in zip(pred_depth, pred_depth_final, gt_depth):\n            depth_mask = gt > 0.\n            abs_rel_normal = np.mean(\n                np.abs(\n                    pred[depth_mask] -\n                    
gt[depth_mask]) /\n                gt[depth_mask])\n            abs_rel_final = np.mean(\n                np.abs(\n                    pred_final[depth_mask] -\n                    gt[depth_mask]) /\n                gt[depth_mask])\n            abs_rels.append(abs_rel_normal)\n            abs_rel_finals.append(abs_rel_final)\n        abs_rel = np.stack(abs_rels).mean(axis=0)\n        abs_rel_final = np.stack(abs_rel_finals).mean(axis=0)\n\n        # calculate the PQs\n        epsilon = 0.\n        sq = iou_per_class / (tp_per_class + epsilon)\n        rq = tp_per_class / (tp_per_class + 0.5 *\n                             fn_per_class + 0.5 * fp_per_class + epsilon)\n        print(\"tp per class\")\n        print(tp_per_class)\n        print(\"fp per class\")\n        print(fp_per_class)\n        print(\"fn per class\")\n        print(fn_per_class)\n\n        pq = sq * rq\n        print(\"PQ (things)\")\n        print(pq[:thing_upper])\n        print(\"PQ (stuff)\")\n        print(pq[thing_upper:])\n        print(\"SQ\")\n        print(sq)\n        print(\"RQ\")\n        print(rq)\n        # thing_before_stuff is True, so thing classes occupy [0, thing_upper) and stuff the rest\n        things_pq = pq[:thing_upper]\n        stuff_pq = pq[thing_upper:]\n\n        return {\n            \"abs_rel\": abs_rel,\n            \"abs_rel_final\": abs_rel_final,\n            \"PQ\": np.nan_to_num(pq).mean() * 100,\n            \"Stuff PQ\": np.nan_to_num(stuff_pq).mean() * 100,\n            \"Things PQ\": np.nan_to_num(things_pq).mean() * 100,\n        }\n\n\ndef vpq_eval(element):\n    import six\n    pred_ids, gt_ids = element\n    max_ins = 10000\n    ign_id = 255\n    offset = 2 ** 30\n    num_cat = 19 + 1\n\n    iou_per_class = np.zeros(num_cat, dtype=np.float64)\n    tp_per_class = np.zeros(num_cat, dtype=np.float64)\n    fn_per_class = np.zeros(num_cat, dtype=np.float64)\n    fp_per_class = np.zeros(num_cat, dtype=np.float64)\n\n    def _ids_to_counts(id_array):\n        ids, counts = np.unique(id_array, return_counts=True)\n        return dict(six.moves.zip(ids, counts))\n\n    pred_areas = 
_ids_to_counts(pred_ids)\n    gt_areas = _ids_to_counts(gt_ids)\n\n    void_id = ign_id * max_ins\n    ign_ids = {\n        gt_id for gt_id in six.iterkeys(gt_areas)\n        if (gt_id // max_ins) == ign_id\n    }\n\n    int_ids = gt_ids.astype(np.int64) * offset + pred_ids.astype(np.int64)\n    int_areas = _ids_to_counts(int_ids)\n\n    def prediction_void_overlap(pred_id):\n        void_int_id = void_id * offset + pred_id\n        return int_areas.get(void_int_id, 0)\n\n    def prediction_ignored_overlap(pred_id):\n        total_ignored_overlap = 0\n        for _ign_id in ign_ids:\n            int_id = _ign_id * offset + pred_id\n            total_ignored_overlap += int_areas.get(int_id, 0)\n        return total_ignored_overlap\n\n    gt_matched = set()\n    pred_matched = set()\n\n    for int_id, int_area in six.iteritems(int_areas):\n        gt_id = int(int_id // offset)\n        gt_cat = int(gt_id // max_ins)\n        pred_id = int(int_id % offset)\n        pred_cat = int(pred_id // max_ins)\n        if gt_cat != pred_cat:\n            continue\n        union = (\n                gt_areas[gt_id] + pred_areas[pred_id] - int_area -\n                prediction_void_overlap(pred_id)\n        )\n        iou = int_area / union\n        if iou > 0.5:\n            tp_per_class[gt_cat] += 1\n            iou_per_class[gt_cat] += iou\n            gt_matched.add(gt_id)\n            pred_matched.add(pred_id)\n\n    for gt_id in six.iterkeys(gt_areas):\n        if gt_id in gt_matched:\n            continue\n        cat_id = gt_id // max_ins\n        if cat_id == ign_id:\n            continue\n        fn_per_class[cat_id] += 1\n\n    for pred_id in six.iterkeys(pred_areas):\n        if pred_id in pred_matched:\n            continue\n        if (prediction_ignored_overlap(pred_id) / pred_areas[pred_id]) > 0.5:\n            continue\n        cat = pred_id // max_ins\n        fp_per_class[cat] += 1\n\n    return iou_per_class, tp_per_class, fn_per_class, fp_per_class\n\n\nif 
__name__ == '__main__':\n    import dataset.dvps_pipelines.loading\n    import dataset.dvps_pipelines.transforms\n    import dataset.pipelines.formatting\n\n    img_norm_cfg = dict(\n        mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=False)\n    data = KITTIDVPSDataset(\n        pipeline=[\n            dict(type='LoadMultiImagesDirect'),\n            dict(type='LoadMultiAnnotationsDirect', with_depth=True, divisor=0),\n            dict(type='SeqResizeWithDepth', img_scale=(1024, 2048), ratio_range=[1.0, 2.0], keep_ratio=True),\n            dict(type='SeqFlipWithDepth', flip_ratio=0.5),\n            dict(type='SeqRandomCropWithDepth', crop_size=(1024, 2048), share_params=True),\n            dict(type='SeqNormalizeWithDepth', **img_norm_cfg),\n            dict(type='SeqPadWithDepth', size_divisor=32),\n            dict(\n                type='VideoCollect',\n                keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks', 'gt_semantic_seg', 'gt_depth', 'gt_instance_ids']),\n            dict(type='ConcatVideoReferences'),\n            dict(type='SeqDefaultFormatBundle', ref_prefix='ref'),\n        ],\n        data_root=os.path.expanduser('~/datasets/kitti-dvps'),\n        split='val',\n        ref_seq_index=[-1, 1]\n    )\n    np.set_string_function(lambda x: '<{} ; {}>'.format(x.shape, x.dtype))\n    torch.set_printoptions(profile='short')\n    for item in data:\n        print(item)\n"
  },
  {
    "path": "external/test.py",
    "content": "import os.path as osp\nimport time\n\nimport mmcv\nimport torch\nfrom mmcv.image import tensor2imgs\nfrom mmcv.runner import get_dist_info\nfrom mmdet.apis.test import collect_results_cpu, collect_results_gpu\nfrom mmdet.core import encode_mask_results\nfrom .utils import encode_panoptic\n\n\ndef single_gpu_test(model,\n                    data_loader,\n                    show=False,\n                    out_dir=None,\n                    show_score_thr=0.3):\n    model.eval()\n    results = []\n    dataset = data_loader.dataset\n    prog_bar = mmcv.ProgressBar(len(dataset))\n    for i, data in enumerate(data_loader):\n        with torch.no_grad():\n            result = model(return_loss=False, rescale=True, **data)\n\n        batch_size = len(result)\n        if show or out_dir:\n            if batch_size == 1 and isinstance(data['img'][0], torch.Tensor):\n                img_tensor = data['img'][0]\n            else:\n                img_tensor = data['img'][0].data[0]\n            img_metas = data['img_metas'][0].data[0]\n            imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])\n            assert len(imgs) == len(img_metas)\n\n            for i, (img, img_meta) in enumerate(zip(imgs, img_metas)):\n                h, w, _ = img_meta['img_shape']\n                img_show = img[:h, :w, :]\n\n                ori_h, ori_w = img_meta['ori_shape'][:-1]\n                img_show = mmcv.imresize(img_show, (ori_w, ori_h))\n\n                if out_dir:\n                    out_file = osp.join(out_dir, img_meta['ori_filename'])\n                else:\n                    out_file = None\n\n                model.module.show_result(\n                    img_show,\n                    result[i],\n                    show=show,\n                    out_file=out_file,\n                    score_thr=show_score_thr)\n\n        # encode mask results\n        if isinstance(result[0], tuple):\n            if len(result[0]) == 2:\n                
result = [(bbox_results, encode_mask_results(mask_results))\n                          for bbox_results, mask_results in result]\n            # Supporting depth here\n            elif len(result[0]) == 5:\n                result = [(bbox_results, mask_results,\n                           seg_results, depth, depth_final)\n                          for bbox_results, mask_results, seg_results, depth, depth_final in result\n                          ]\n            else:\n                result = [(bbox_results, encode_mask_results(mask_results),\n                           encode_panoptic(seg_results))\n                          for bbox_results, mask_results, seg_results in result\n                          ]\n        results.extend(result)\n\n        for _ in range(batch_size):\n            prog_bar.update()\n    return results\n\n\ndef multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):\n    \"\"\"Test model with multiple gpus.\n\n    This method tests model with multiple gpus and collects the results\n    under two different modes: gpu and cpu modes. By setting 'gpu_collect=True'\n    it encodes results to gpu tensors and use gpu communication for results\n    collection. 
On cpu mode it saves the results on different gpus to 'tmpdir'\n    and collects them by the rank 0 worker.\n\n    Args:\n        model (nn.Module): Model to be tested.\n        data_loader (nn.Dataloader): Pytorch data loader.\n        tmpdir (str): Path of directory to save the temporary results from\n            different gpus under cpu mode.\n        gpu_collect (bool): Option to use either gpu or cpu to collect results.\n\n    Returns:\n        list: The prediction results.\n    \"\"\"\n    model.eval()\n    results = []\n    dataset = data_loader.dataset\n    rank, world_size = get_dist_info()\n    if rank == 0:\n        prog_bar = mmcv.ProgressBar(len(dataset))\n    time.sleep(2)  # This line can prevent deadlock problem in some cases.\n    for i, data in enumerate(data_loader):\n        with torch.no_grad():\n            result = model(return_loss=False, rescale=True, **data)\n            # encode mask results\n            if isinstance(result[0], tuple):\n                if len(result[0]) == 2:\n                    result = [(bbox_results, encode_mask_results(mask_results))\n                              for bbox_results, mask_results in result]\n                # Supporting depth here\n                elif len(result[0]) == 5:\n                    result = [(bbox_results, mask_results,\n                               seg_results, depth, depth_final)\n                              for bbox_results, mask_results, seg_results, depth, depth_final in result\n                              ]\n                else:\n                    result = [\n                        (bbox_results, encode_mask_results(mask_results),\n                         encode_panoptic(seg_results))\n                        for bbox_results, mask_results, seg_results in result\n                    ]\n        results.extend(result)\n\n        if rank == 0:\n            batch_size = len(result)\n            for _ in range(batch_size * world_size):\n                prog_bar.update()\n\n    
# collect results from all ranks\n    if gpu_collect:\n        results = collect_results_gpu(results, len(dataset))\n    else:\n        results = collect_results_cpu(results, len(dataset), tmpdir)\n    return results\n"
  },
  {
    "path": "external/train.py",
    "content": "import warnings\n\nimport torch\nfrom mmcv.parallel import MMDataParallel, MMDistributedDataParallel\nfrom mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner,\n                         Fp16OptimizerHook, OptimizerHook, build_optimizer,\n                         build_runner)\nfrom mmcv.utils import build_from_cfg\nfrom mmdet.datasets import (build_dataloader, build_dataset,\n                            replace_ImageToTensor)\nfrom mmdet.utils import get_root_logger\n\nfrom external.evalhooks import DistEvalHook, EvalHook\n\n\ndef train_detector(model,\n                   dataset,\n                   cfg,\n                   distributed=False,\n                   validate=False,\n                   timestamp=None,\n                   meta=None):\n    logger = get_root_logger(cfg.log_level)\n\n    # prepare data loaders\n    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]\n    if 'imgs_per_gpu' in cfg.data:\n        logger.warning('\"imgs_per_gpu\" is deprecated in MMDet V2.0. 
'\n                       'Please use \"samples_per_gpu\" instead')\n        if 'samples_per_gpu' in cfg.data:\n            logger.warning(\n                f'Got \"imgs_per_gpu\"={cfg.data.imgs_per_gpu} and '\n                f'\"samples_per_gpu\"={cfg.data.samples_per_gpu}, \"imgs_per_gpu\"'\n                f'={cfg.data.imgs_per_gpu} is used in this experiment')\n        else:\n            logger.warning(\n                'Automatically set \"samples_per_gpu\"=\"imgs_per_gpu\"='\n                f'{cfg.data.imgs_per_gpu} in this experiment')\n        cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu\n\n    data_loaders = [\n        build_dataloader(\n            ds,\n            cfg.data.samples_per_gpu,\n            cfg.data.workers_per_gpu,\n            # cfg.gpus will be ignored if distributed\n            len(cfg.gpu_ids),\n            dist=distributed,\n            seed=cfg.seed) for ds in dataset\n    ]\n\n    # put model on gpus\n    if distributed:\n        find_unused_parameters = cfg.get('find_unused_parameters', False)\n        # Sets the `find_unused_parameters` parameter in\n        # torch.nn.parallel.DistributedDataParallel\n        model = MMDistributedDataParallel(\n            model.cuda(),\n            device_ids=[torch.cuda.current_device()],\n            broadcast_buffers=False,\n            find_unused_parameters=find_unused_parameters)\n    else:\n        model = MMDataParallel(\n            model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)\n\n    # build runner\n    optimizer = build_optimizer(model, cfg.optimizer)\n\n    if 'runner' not in cfg:\n        cfg.runner = {\n            'type': 'EpochBasedRunner',\n            'max_epochs': cfg.total_epochs\n        }\n        warnings.warn(\n            'config is now expected to have a `runner` section, '\n            'please set `runner` in your config.', UserWarning)\n    else:\n        if 'total_epochs' in cfg:\n            assert cfg.total_epochs == cfg.runner.max_epochs\n\n    runner 
= build_runner(\n        cfg.runner,\n        default_args=dict(\n            model=model,\n            optimizer=optimizer,\n            work_dir=cfg.work_dir,\n            logger=logger,\n            meta=meta))\n\n    # an ugly workaround to make .log and .log.json filenames the same\n    runner.timestamp = timestamp\n\n    # fp16 setting\n    fp16_cfg = cfg.get('fp16', None)\n    if fp16_cfg is not None:\n        optimizer_config = Fp16OptimizerHook(\n            **cfg.optimizer_config, **fp16_cfg, distributed=distributed)\n    elif distributed and 'type' not in cfg.optimizer_config:\n        optimizer_config = OptimizerHook(**cfg.optimizer_config)\n    else:\n        optimizer_config = cfg.optimizer_config\n\n    # register hooks\n    runner.register_training_hooks(cfg.lr_config, optimizer_config,\n                                   cfg.checkpoint_config, cfg.log_config,\n                                   cfg.get('momentum_config', None))\n    if distributed:\n        if isinstance(runner, EpochBasedRunner):\n            runner.register_hook(DistSamplerSeedHook())\n\n    # register eval hooks\n    if validate:\n        # Support batch_size > 1 in validation\n        val_samples_per_gpu = cfg.data.val.pop('samples_per_gpu', 1)\n        if val_samples_per_gpu > 1:\n            # Replace 'ImageToTensor' to 'DefaultFormatBundle'\n            cfg.data.val.pipeline = replace_ImageToTensor(\n                cfg.data.val.pipeline)\n        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))\n        val_dataloader = build_dataloader(\n            val_dataset,\n            samples_per_gpu=val_samples_per_gpu,\n            workers_per_gpu=cfg.data.workers_per_gpu,\n            dist=distributed,\n            shuffle=False)\n        eval_cfg = cfg.get('evaluation', {})\n        eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'\n        eval_hook = DistEvalHook if distributed else EvalHook\n        runner.register_hook(eval_hook(val_dataloader, 
**eval_cfg))\n\n    # user-defined hooks\n    if cfg.get('custom_hooks', None):\n        custom_hooks = cfg.custom_hooks\n        assert isinstance(custom_hooks, list), \\\n            f'custom_hooks expect list type, but got {type(custom_hooks)}'\n        for hook_cfg in cfg.custom_hooks:\n            assert isinstance(hook_cfg, dict), \\\n                'Each item in custom_hooks expects dict type, but got ' \\\n                f'{type(hook_cfg)}'\n            hook_cfg = hook_cfg.copy()\n            priority = hook_cfg.pop('priority', 'NORMAL')\n            hook = build_from_cfg(hook_cfg, HOOKS)\n            runner.register_hook(hook, priority=priority)\n\n    if cfg.resume_from:\n        runner.resume(cfg.resume_from)\n    elif cfg.load_from:\n        runner.load_checkpoint(cfg.load_from)\n    runner.run(data_loaders, cfg.workflow)\n"
  },
  {
    "path": "external/utils.py",
    "content": "import io\n\nfrom panopticapi.utils import id2rgb\nfrom PIL import Image\n\n\ndef encode_panoptic(panoptic_results):\n    panoptic_img, segments_info = panoptic_results\n    with io.BytesIO() as out:\n        Image.fromarray(id2rgb(panoptic_img)).save(out, format='PNG')\n        return out.getvalue(), segments_info\n"
  },
  {
    "path": "external/vipseg_dvps.py",
    "content": "import os\nimport random\nfrom typing import Dict, List\n\nimport copy\n\nimport mmcv\nimport numpy as np\nimport torch\n\nfrom mmdet.datasets.builder import DATASETS\nfrom mmdet.datasets.pipelines import Compose\nfrom mmdet.utils import get_root_logger\n\nCLASSES = [\n    {\"id\": 0, \"name\": \"wall\", \"isthing\": 0, \"color\": [120, 120, 120]},\n    {\"id\": 1, \"name\": \"ceiling\", \"isthing\": 0, \"color\": [180, 120, 120]},\n    {\"id\": 2, \"name\": \"door\", \"isthing\": 1, \"color\": [6, 230, 230]},\n    {\"id\": 3, \"name\": \"stair\", \"isthing\": 0, \"color\": [80, 50, 50]},\n    {\"id\": 4, \"name\": \"ladder\", \"isthing\": 1, \"color\": [4, 200, 3]},\n    {\"id\": 5, \"name\": \"escalator\", \"isthing\": 0, \"color\": [120, 120, 80]},\n    {\"id\": 6, \"name\": \"Playground_slide\", \"isthing\": 0, \"color\": [140, 140, 140]},\n    {\"id\": 7, \"name\": \"handrail_or_fence\", \"isthing\": 0, \"color\": [204, 5, 255]},\n    {\"id\": 8, \"name\": \"window\", \"isthing\": 1, \"color\": [230, 230, 230]},\n    {\"id\": 9, \"name\": \"rail\", \"isthing\": 0, \"color\": [4, 250, 7]},\n    {\"id\": 10, \"name\": \"goal\", \"isthing\": 1, \"color\": [224, 5, 255]},\n    {\"id\": 11, \"name\": \"pillar\", \"isthing\": 0, \"color\": [235, 255, 7]},\n    {\"id\": 12, \"name\": \"pole\", \"isthing\": 0, \"color\": [150, 5, 61]},\n    {\"id\": 13, \"name\": \"floor\", \"isthing\": 0, \"color\": [120, 120, 70]},\n    {\"id\": 14, \"name\": \"ground\", \"isthing\": 0, \"color\": [8, 255, 51]},\n    {\"id\": 15, \"name\": \"grass\", \"isthing\": 0, \"color\": [255, 6, 82]},\n    {\"id\": 16, \"name\": \"sand\", \"isthing\": 0, \"color\": [143, 255, 140]},\n    {\"id\": 17, \"name\": \"athletic_field\", \"isthing\": 0, \"color\": [204, 255, 4]},\n    {\"id\": 18, \"name\": \"road\", \"isthing\": 0, \"color\": [255, 51, 7]},\n    {\"id\": 19, \"name\": \"path\", \"isthing\": 0, \"color\": [204, 70, 3]},\n    {\"id\": 20, \"name\": \"crosswalk\", 
\"isthing\": 0, \"color\": [0, 102, 200]},\n    {\"id\": 21, \"name\": \"building\", \"isthing\": 0, \"color\": [61, 230, 250]},\n    {\"id\": 22, \"name\": \"house\", \"isthing\": 0, \"color\": [255, 6, 51]},\n    {\"id\": 23, \"name\": \"bridge\", \"isthing\": 0, \"color\": [11, 102, 255]},\n    {\"id\": 24, \"name\": \"tower\", \"isthing\": 0, \"color\": [255, 7, 71]},\n    {\"id\": 25, \"name\": \"windmill\", \"isthing\": 0, \"color\": [255, 9, 224]},\n    {\"id\": 26, \"name\": \"well_or_well_lid\", \"isthing\": 0, \"color\": [9, 7, 230]},\n    {\"id\": 27, \"name\": \"other_construction\", \"isthing\": 0, \"color\": [220, 220, 220]},\n    {\"id\": 28, \"name\": \"sky\", \"isthing\": 0, \"color\": [255, 9, 92]},\n    {\"id\": 29, \"name\": \"mountain\", \"isthing\": 0, \"color\": [112, 9, 255]},\n    {\"id\": 30, \"name\": \"stone\", \"isthing\": 0, \"color\": [8, 255, 214]},\n    {\"id\": 31, \"name\": \"wood\", \"isthing\": 0, \"color\": [7, 255, 224]},\n    {\"id\": 32, \"name\": \"ice\", \"isthing\": 0, \"color\": [255, 184, 6]},\n    {\"id\": 33, \"name\": \"snowfield\", \"isthing\": 0, \"color\": [10, 255, 71]},\n    {\"id\": 34, \"name\": \"grandstand\", \"isthing\": 0, \"color\": [255, 41, 10]},\n    {\"id\": 35, \"name\": \"sea\", \"isthing\": 0, \"color\": [7, 255, 255]},\n    {\"id\": 36, \"name\": \"river\", \"isthing\": 0, \"color\": [224, 255, 8]},\n    {\"id\": 37, \"name\": \"lake\", \"isthing\": 0, \"color\": [102, 8, 255]},\n    {\"id\": 38, \"name\": \"waterfall\", \"isthing\": 0, \"color\": [255, 61, 6]},\n    {\"id\": 39, \"name\": \"water\", \"isthing\": 0, \"color\": [255, 194, 7]},\n    {\"id\": 40, \"name\": \"billboard_or_Bulletin_Board\", \"isthing\": 0, \"color\": [255, 122, 8]},\n    {\"id\": 41, \"name\": \"sculpture\", \"isthing\": 1, \"color\": [0, 255, 20]},\n    {\"id\": 42, \"name\": \"pipeline\", \"isthing\": 0, \"color\": [255, 8, 41]},\n    {\"id\": 43, \"name\": \"flag\", \"isthing\": 1, \"color\": [255, 5, 153]},\n    
{\"id\": 44, \"name\": \"parasol_or_umbrella\", \"isthing\": 1, \"color\": [6, 51, 255]},\n    {\"id\": 45, \"name\": \"cushion_or_carpet\", \"isthing\": 0, \"color\": [235, 12, 255]},\n    {\"id\": 46, \"name\": \"tent\", \"isthing\": 1, \"color\": [160, 150, 20]},\n    {\"id\": 47, \"name\": \"roadblock\", \"isthing\": 1, \"color\": [0, 163, 255]},\n    {\"id\": 48, \"name\": \"car\", \"isthing\": 1, \"color\": [140, 140, 140]},\n    {\"id\": 49, \"name\": \"bus\", \"isthing\": 1, \"color\": [250, 10, 15]},\n    {\"id\": 50, \"name\": \"truck\", \"isthing\": 1, \"color\": [20, 255, 0]},\n    {\"id\": 51, \"name\": \"bicycle\", \"isthing\": 1, \"color\": [31, 255, 0]},\n    {\"id\": 52, \"name\": \"motorcycle\", \"isthing\": 1, \"color\": [255, 31, 0]},\n    {\"id\": 53, \"name\": \"wheeled_machine\", \"isthing\": 0, \"color\": [255, 224, 0]},\n    {\"id\": 54, \"name\": \"ship_or_boat\", \"isthing\": 1, \"color\": [153, 255, 0]},\n    {\"id\": 55, \"name\": \"raft\", \"isthing\": 1, \"color\": [0, 0, 255]},\n    {\"id\": 56, \"name\": \"airplane\", \"isthing\": 1, \"color\": [255, 71, 0]},\n    {\"id\": 57, \"name\": \"tyre\", \"isthing\": 0, \"color\": [0, 235, 255]},\n    {\"id\": 58, \"name\": \"traffic_light\", \"isthing\": 0, \"color\": [0, 173, 255]},\n    {\"id\": 59, \"name\": \"lamp\", \"isthing\": 0, \"color\": [31, 0, 255]},\n    {\"id\": 60, \"name\": \"person\", \"isthing\": 1, \"color\": [11, 200, 200]},\n    {\"id\": 61, \"name\": \"cat\", \"isthing\": 1, \"color\": [255, 82, 0]},\n    {\"id\": 62, \"name\": \"dog\", \"isthing\": 1, \"color\": [0, 255, 245]},\n    {\"id\": 63, \"name\": \"horse\", \"isthing\": 1, \"color\": [0, 61, 255]},\n    {\"id\": 64, \"name\": \"cattle\", \"isthing\": 1, \"color\": [0, 255, 112]},\n    {\"id\": 65, \"name\": \"other_animal\", \"isthing\": 1, \"color\": [0, 255, 133]},\n    {\"id\": 66, \"name\": \"tree\", \"isthing\": 0, \"color\": [255, 0, 0]},\n    {\"id\": 67, \"name\": \"flower\", \"isthing\": 0, 
\"color\": [255, 163, 0]},\n    {\"id\": 68, \"name\": \"other_plant\", \"isthing\": 0, \"color\": [255, 102, 0]},\n    {\"id\": 69, \"name\": \"toy\", \"isthing\": 0, \"color\": [194, 255, 0]},\n    {\"id\": 70, \"name\": \"ball_net\", \"isthing\": 0, \"color\": [0, 143, 255]},\n    {\"id\": 71, \"name\": \"backboard\", \"isthing\": 0, \"color\": [51, 255, 0]},\n    {\"id\": 72, \"name\": \"skateboard\", \"isthing\": 1, \"color\": [0, 82, 255]},\n    {\"id\": 73, \"name\": \"bat\", \"isthing\": 0, \"color\": [0, 255, 41]},\n    {\"id\": 74, \"name\": \"ball\", \"isthing\": 1, \"color\": [0, 255, 173]},\n    {\"id\": 75, \"name\": \"cupboard_or_showcase_or_storage_rack\", \"isthing\": 0, \"color\": [10, 0, 255]},\n    {\"id\": 76, \"name\": \"box\", \"isthing\": 1, \"color\": [173, 255, 0]},\n    {\"id\": 77, \"name\": \"traveling_case_or_trolley_case\", \"isthing\": 1, \"color\": [0, 255, 153]},\n    {\"id\": 78, \"name\": \"basket\", \"isthing\": 1, \"color\": [255, 92, 0]},\n    {\"id\": 79, \"name\": \"bag_or_package\", \"isthing\": 1, \"color\": [255, 0, 255]},\n    {\"id\": 80, \"name\": \"trash_can\", \"isthing\": 0, \"color\": [255, 0, 245]},\n    {\"id\": 81, \"name\": \"cage\", \"isthing\": 0, \"color\": [255, 0, 102]},\n    {\"id\": 82, \"name\": \"plate\", \"isthing\": 1, \"color\": [255, 173, 0]},\n    {\"id\": 83, \"name\": \"tub_or_bowl_or_pot\", \"isthing\": 1, \"color\": [255, 0, 20]},\n    {\"id\": 84, \"name\": \"bottle_or_cup\", \"isthing\": 1, \"color\": [255, 184, 184]},\n    {\"id\": 85, \"name\": \"barrel\", \"isthing\": 1, \"color\": [0, 31, 255]},\n    {\"id\": 86, \"name\": \"fishbowl\", \"isthing\": 1, \"color\": [0, 255, 61]},\n    {\"id\": 87, \"name\": \"bed\", \"isthing\": 1, \"color\": [0, 71, 255]},\n    {\"id\": 88, \"name\": \"pillow\", \"isthing\": 1, \"color\": [255, 0, 204]},\n    {\"id\": 89, \"name\": \"table_or_desk\", \"isthing\": 1, \"color\": [0, 255, 194]},\n    {\"id\": 90, \"name\": \"chair_or_seat\", \"isthing\": 1, 
\"color\": [0, 255, 82]},\n    {\"id\": 91, \"name\": \"bench\", \"isthing\": 1, \"color\": [0, 10, 255]},\n    {\"id\": 92, \"name\": \"sofa\", \"isthing\": 1, \"color\": [0, 112, 255]},\n    {\"id\": 93, \"name\": \"shelf\", \"isthing\": 0, \"color\": [51, 0, 255]},\n    {\"id\": 94, \"name\": \"bathtub\", \"isthing\": 0, \"color\": [0, 194, 255]},\n    {\"id\": 95, \"name\": \"gun\", \"isthing\": 1, \"color\": [0, 122, 255]},\n    {\"id\": 96, \"name\": \"commode\", \"isthing\": 1, \"color\": [0, 255, 163]},\n    {\"id\": 97, \"name\": \"roaster\", \"isthing\": 1, \"color\": [255, 153, 0]},\n    {\"id\": 98, \"name\": \"other_machine\", \"isthing\": 0, \"color\": [0, 255, 10]},\n    {\"id\": 99, \"name\": \"refrigerator\", \"isthing\": 1, \"color\": [255, 112, 0]},\n    {\"id\": 100, \"name\": \"washing_machine\", \"isthing\": 1, \"color\": [143, 255, 0]},\n    {\"id\": 101, \"name\": \"Microwave_oven\", \"isthing\": 1, \"color\": [82, 0, 255]},\n    {\"id\": 102, \"name\": \"fan\", \"isthing\": 1, \"color\": [163, 255, 0]},\n    {\"id\": 103, \"name\": \"curtain\", \"isthing\": 0, \"color\": [255, 235, 0]},\n    {\"id\": 104, \"name\": \"textiles\", \"isthing\": 0, \"color\": [8, 184, 170]},\n    {\"id\": 105, \"name\": \"clothes\", \"isthing\": 0, \"color\": [133, 0, 255]},\n    {\"id\": 106, \"name\": \"painting_or_poster\", \"isthing\": 1, \"color\": [0, 255, 92]},\n    {\"id\": 107, \"name\": \"mirror\", \"isthing\": 1, \"color\": [184, 0, 255]},\n    {\"id\": 108, \"name\": \"flower_pot_or_vase\", \"isthing\": 1, \"color\": [255, 0, 31]},\n    {\"id\": 109, \"name\": \"clock\", \"isthing\": 1, \"color\": [0, 184, 255]},\n    {\"id\": 110, \"name\": \"book\", \"isthing\": 0, \"color\": [0, 214, 255]},\n    {\"id\": 111, \"name\": \"tool\", \"isthing\": 0, \"color\": [255, 0, 112]},\n    {\"id\": 112, \"name\": \"blackboard\", \"isthing\": 0, \"color\": [92, 255, 0]},\n    {\"id\": 113, \"name\": \"tissue\", \"isthing\": 0, \"color\": [0, 224, 255]},\n    
{\"id\": 114, \"name\": \"screen_or_television\", \"isthing\": 1, \"color\": [112, 224, 255]},\n    {\"id\": 115, \"name\": \"computer\", \"isthing\": 1, \"color\": [70, 184, 160]},\n    {\"id\": 116, \"name\": \"printer\", \"isthing\": 1, \"color\": [163, 0, 255]},\n    {\"id\": 117, \"name\": \"Mobile_phone\", \"isthing\": 1, \"color\": [153, 0, 255]},\n    {\"id\": 118, \"name\": \"keyboard\", \"isthing\": 1, \"color\": [71, 255, 0]},\n    {\"id\": 119, \"name\": \"other_electronic_product\", \"isthing\": 0, \"color\": [255, 0, 163]},\n    {\"id\": 120, \"name\": \"fruit\", \"isthing\": 0, \"color\": [255, 204, 0]},\n    {\"id\": 121, \"name\": \"food\", \"isthing\": 0, \"color\": [255, 0, 143]},\n    {\"id\": 122, \"name\": \"instrument\", \"isthing\": 1, \"color\": [0, 255, 235]},\n    {\"id\": 123, \"name\": \"train\", \"isthing\": 1, \"color\": [133, 255, 0]}\n]\n\nCLASSES_THING = [\n    {'id': 2, 'name': 'door', 'isthing': 1, 'color': [6, 230, 230]},\n    {'id': 4, 'name': 'ladder', 'isthing': 1, 'color': [4, 200, 3]},\n    {'id': 8, 'name': 'window', 'isthing': 1, 'color': [230, 230, 230]},\n    {'id': 10, 'name': 'goal', 'isthing': 1, 'color': [224, 5, 255]},\n    {'id': 41, 'name': 'sculpture', 'isthing': 1, 'color': [0, 255, 20]},\n    {'id': 43, 'name': 'flag', 'isthing': 1, 'color': [255, 5, 153]},\n    {'id': 44, 'name': 'parasol_or_umbrella', 'isthing': 1, 'color': [6, 51, 255]},\n    {'id': 46, 'name': 'tent', 'isthing': 1, 'color': [160, 150, 20]},\n    {'id': 47, 'name': 'roadblock', 'isthing': 1, 'color': [0, 163, 255]},\n    {'id': 48, 'name': 'car', 'isthing': 1, 'color': [140, 140, 140]},\n    {'id': 49, 'name': 'bus', 'isthing': 1, 'color': [250, 10, 15]},\n    {'id': 50, 'name': 'truck', 'isthing': 1, 'color': [20, 255, 0]},\n    {'id': 51, 'name': 'bicycle', 'isthing': 1, 'color': [31, 255, 0]},\n    {'id': 52, 'name': 'motorcycle', 'isthing': 1, 'color': [255, 31, 0]},\n    {'id': 54, 'name': 'ship_or_boat', 'isthing': 1, 'color': [153, 
255, 0]},\n    {'id': 55, 'name': 'raft', 'isthing': 1, 'color': [0, 0, 255]},\n    {'id': 56, 'name': 'airplane', 'isthing': 1, 'color': [255, 71, 0]},\n    {'id': 60, 'name': 'person', 'isthing': 1, 'color': [11, 200, 200]},\n    {'id': 61, 'name': 'cat', 'isthing': 1, 'color': [255, 82, 0]},\n    {'id': 62, 'name': 'dog', 'isthing': 1, 'color': [0, 255, 245]},\n    {'id': 63, 'name': 'horse', 'isthing': 1, 'color': [0, 61, 255]},\n    {'id': 64, 'name': 'cattle', 'isthing': 1, 'color': [0, 255, 112]},\n    {'id': 65, 'name': 'other_animal', 'isthing': 1, 'color': [0, 255, 133]},\n    {'id': 72, 'name': 'skateboard', 'isthing': 1, 'color': [0, 82, 255]},\n    {'id': 74, 'name': 'ball', 'isthing': 1, 'color': [0, 255, 173]},\n    {'id': 76, 'name': 'box', 'isthing': 1, 'color': [173, 255, 0]},\n    {'id': 77, 'name': 'traveling_case_or_trolley_case', 'isthing': 1, 'color': [0, 255, 153]},\n    {'id': 78, 'name': 'basket', 'isthing': 1, 'color': [255, 92, 0]},\n    {'id': 79, 'name': 'bag_or_package', 'isthing': 1, 'color': [255, 0, 255]},\n    {'id': 82, 'name': 'plate', 'isthing': 1, 'color': [255, 173, 0]},\n    {'id': 83, 'name': 'tub_or_bowl_or_pot', 'isthing': 1, 'color': [255, 0, 20]},\n    {'id': 84, 'name': 'bottle_or_cup', 'isthing': 1, 'color': [255, 184, 184]},\n    {'id': 85, 'name': 'barrel', 'isthing': 1, 'color': [0, 31, 255]},\n    {'id': 86, 'name': 'fishbowl', 'isthing': 1, 'color': [0, 255, 61]},\n    {'id': 87, 'name': 'bed', 'isthing': 1, 'color': [0, 71, 255]},\n    {'id': 88, 'name': 'pillow', 'isthing': 1, 'color': [255, 0, 204]},\n    {'id': 89, 'name': 'table_or_desk', 'isthing': 1, 'color': [0, 255, 194]},\n    {'id': 90, 'name': 'chair_or_seat', 'isthing': 1, 'color': [0, 255, 82]},\n    {'id': 91, 'name': 'bench', 'isthing': 1, 'color': [0, 10, 255]},\n    {'id': 92, 'name': 'sofa', 'isthing': 1, 'color': [0, 112, 255]},\n    {'id': 95, 'name': 'gun', 'isthing': 1, 'color': [0, 122, 255]},\n    {'id': 96, 'name': 'commode', 'isthing': 
1, 'color': [0, 255, 163]},\n    {'id': 97, 'name': 'roaster', 'isthing': 1, 'color': [255, 153, 0]},\n    {'id': 99, 'name': 'refrigerator', 'isthing': 1, 'color': [255, 112, 0]},\n    {'id': 100, 'name': 'washing_machine', 'isthing': 1, 'color': [143, 255, 0]},\n    {'id': 101, 'name': 'Microwave_oven', 'isthing': 1, 'color': [82, 0, 255]},\n    {'id': 102, 'name': 'fan', 'isthing': 1, 'color': [163, 255, 0]},\n    {'id': 106, 'name': 'painting_or_poster', 'isthing': 1, 'color': [0, 255, 92]},\n    {'id': 107, 'name': 'mirror', 'isthing': 1, 'color': [184, 0, 255]},\n    {'id': 108, 'name': 'flower_pot_or_vase', 'isthing': 1, 'color': [255, 0, 31]},\n    {'id': 109, 'name': 'clock', 'isthing': 1, 'color': [0, 184, 255]},\n    {'id': 114, 'name': 'screen_or_television', 'isthing': 1, 'color': [112, 224, 255]},\n    {'id': 115, 'name': 'computer', 'isthing': 1, 'color': [70, 184, 160]},\n    {'id': 116, 'name': 'printer', 'isthing': 1, 'color': [163, 0, 255]},\n    {'id': 117, 'name': 'Mobile_phone', 'isthing': 1, 'color': [153, 0, 255]},\n    {'id': 118, 'name': 'keyboard', 'isthing': 1, 'color': [71, 255, 0]},\n    {'id': 122, 'name': 'instrument', 'isthing': 1, 'color': [0, 255, 235]},\n    {'id': 123, 'name': 'train', 'isthing': 1, 'color': [133, 255, 0]}\n]\n\nCLASSES_STUFF = [\n    {'id': 0, 'name': 'wall', 'isthing': 0, 'color': [120, 120, 120]},\n    {'id': 1, 'name': 'ceiling', 'isthing': 0, 'color': [180, 120, 120]},\n    {'id': 3, 'name': 'stair', 'isthing': 0, 'color': [80, 50, 50]},\n    {'id': 5, 'name': 'escalator', 'isthing': 0, 'color': [120, 120, 80]},\n    {'id': 6, 'name': 'Playground_slide', 'isthing': 0, 'color': [140, 140, 140]},\n    {'id': 7, 'name': 'handrail_or_fence', 'isthing': 0, 'color': [204, 5, 255]},\n    {'id': 9, 'name': 'rail', 'isthing': 0, 'color': [4, 250, 7]},\n    {'id': 11, 'name': 'pillar', 'isthing': 0, 'color': [235, 255, 7]},\n    {'id': 12, 'name': 'pole', 'isthing': 0, 'color': [150, 5, 61]},\n    {'id': 13, 'name': 
'floor', 'isthing': 0, 'color': [120, 120, 70]},\n    {'id': 14, 'name': 'ground', 'isthing': 0, 'color': [8, 255, 51]},\n    {'id': 15, 'name': 'grass', 'isthing': 0, 'color': [255, 6, 82]},\n    {'id': 16, 'name': 'sand', 'isthing': 0, 'color': [143, 255, 140]},\n    {'id': 17, 'name': 'athletic_field', 'isthing': 0, 'color': [204, 255, 4]},\n    {'id': 18, 'name': 'road', 'isthing': 0, 'color': [255, 51, 7]},\n    {'id': 19, 'name': 'path', 'isthing': 0, 'color': [204, 70, 3]},\n    {'id': 20, 'name': 'crosswalk', 'isthing': 0, 'color': [0, 102, 200]},\n    {'id': 21, 'name': 'building', 'isthing': 0, 'color': [61, 230, 250]},\n    {'id': 22, 'name': 'house', 'isthing': 0, 'color': [255, 6, 51]},\n    {'id': 23, 'name': 'bridge', 'isthing': 0, 'color': [11, 102, 255]},\n    {'id': 24, 'name': 'tower', 'isthing': 0, 'color': [255, 7, 71]},\n    {'id': 25, 'name': 'windmill', 'isthing': 0, 'color': [255, 9, 224]},\n    {'id': 26, 'name': 'well_or_well_lid', 'isthing': 0, 'color': [9, 7, 230]},\n    {'id': 27, 'name': 'other_construction', 'isthing': 0, 'color': [220, 220, 220]},\n    {'id': 28, 'name': 'sky', 'isthing': 0, 'color': [255, 9, 92]},\n    {'id': 29, 'name': 'mountain', 'isthing': 0, 'color': [112, 9, 255]},\n    {'id': 30, 'name': 'stone', 'isthing': 0, 'color': [8, 255, 214]},\n    {'id': 31, 'name': 'wood', 'isthing': 0, 'color': [7, 255, 224]},\n    {'id': 32, 'name': 'ice', 'isthing': 0, 'color': [255, 184, 6]},\n    {'id': 33, 'name': 'snowfield', 'isthing': 0, 'color': [10, 255, 71]},\n    {'id': 34, 'name': 'grandstand', 'isthing': 0, 'color': [255, 41, 10]},\n    {'id': 35, 'name': 'sea', 'isthing': 0, 'color': [7, 255, 255]},\n    {'id': 36, 'name': 'river', 'isthing': 0, 'color': [224, 255, 8]},\n    {'id': 37, 'name': 'lake', 'isthing': 0, 'color': [102, 8, 255]},\n    {'id': 38, 'name': 'waterfall', 'isthing': 0, 'color': [255, 61, 6]},\n    {'id': 39, 'name': 'water', 'isthing': 0, 'color': [255, 194, 7]},\n    {'id': 40, 'name': 
'billboard_or_Bulletin_Board', 'isthing': 0, 'color': [255, 122, 8]},\n    {'id': 42, 'name': 'pipeline', 'isthing': 0, 'color': [255, 8, 41]},\n    {'id': 45, 'name': 'cushion_or_carpet', 'isthing': 0, 'color': [235, 12, 255]},\n    {'id': 53, 'name': 'wheeled_machine', 'isthing': 0, 'color': [255, 224, 0]},\n    {'id': 57, 'name': 'tyre', 'isthing': 0, 'color': [0, 235, 255]},\n    {'id': 58, 'name': 'traffic_light', 'isthing': 0, 'color': [0, 173, 255]},\n    {'id': 59, 'name': 'lamp', 'isthing': 0, 'color': [31, 0, 255]},\n    {'id': 66, 'name': 'tree', 'isthing': 0, 'color': [255, 0, 0]},\n    {'id': 67, 'name': 'flower', 'isthing': 0, 'color': [255, 163, 0]},\n    {'id': 68, 'name': 'other_plant', 'isthing': 0, 'color': [255, 102, 0]},\n    {'id': 69, 'name': 'toy', 'isthing': 0, 'color': [194, 255, 0]},\n    {'id': 70, 'name': 'ball_net', 'isthing': 0, 'color': [0, 143, 255]},\n    {'id': 71, 'name': 'backboard', 'isthing': 0, 'color': [51, 255, 0]},\n    {'id': 73, 'name': 'bat', 'isthing': 0, 'color': [0, 255, 41]},\n    {'id': 75, 'name': 'cupboard_or_showcase_or_storage_rack', 'isthing': 0, 'color': [10, 0, 255]},\n    {'id': 80, 'name': 'trash_can', 'isthing': 0, 'color': [255, 0, 245]},\n    {'id': 81, 'name': 'cage', 'isthing': 0, 'color': [255, 0, 102]},\n    {'id': 93, 'name': 'shelf', 'isthing': 0, 'color': [51, 0, 255]},\n    {'id': 94, 'name': 'bathtub', 'isthing': 0, 'color': [0, 194, 255]},\n    {'id': 98, 'name': 'other_machine', 'isthing': 0, 'color': [0, 255, 10]},\n    {'id': 103, 'name': 'curtain', 'isthing': 0, 'color': [255, 235, 0]},\n    {'id': 104, 'name': 'textiles', 'isthing': 0, 'color': [8, 184, 170]},\n    {'id': 105, 'name': 'clothes', 'isthing': 0, 'color': [133, 0, 255]},\n    {'id': 110, 'name': 'book', 'isthing': 0, 'color': [0, 214, 255]},\n    {'id': 111, 'name': 'tool', 'isthing': 0, 'color': [255, 0, 112]},\n    {'id': 112, 'name': 'blackboard', 'isthing': 0, 'color': [92, 255, 0]},\n    {'id': 113, 'name': 'tissue', 
'isthing': 0, 'color': [0, 224, 255]},\n    {'id': 119, 'name': 'other_electronic_product', 'isthing': 0, 'color': [255, 0, 163]},\n    {'id': 120, 'name': 'fruit', 'isthing': 0, 'color': [255, 204, 0]},\n    {'id': 121, 'name': 'food', 'isthing': 0, 'color': [255, 0, 143]}\n]\n\n# stuff -> thing\nNO_OBJ = 0\nNO_OBJ_HB = 255\nDIVISOR_PAN = 100\nDIVISOR_NEW = 1000\nNUM_THING = 58\nNUM_STUFF = 66\nTHING_B_STUFF = False\n\n\ndef vip2hb(pan_map):\n    assert not THING_B_STUFF, \"VIPSeg only supports stuff -> thing\"\n    pan_new = - np.ones_like(pan_map)\n    vip2hb_thing = {itm['id'] + 1: idx for idx, itm in enumerate(CLASSES_THING)}\n    vip2hb_stuff = {itm['id'] + 1: idx for idx, itm in enumerate(CLASSES_STUFF)}\n    for idx in np.unique(pan_map):\n        if idx == NO_OBJ or idx == 200:\n            pan_new[pan_map == idx] = NO_OBJ_HB\n        elif idx > 128:\n            cls_id = idx // DIVISOR_PAN\n            cls_new_id = vip2hb_thing[cls_id]\n            inst_id = idx % DIVISOR_PAN\n            # since stuff -> thing\n            cls_new_id += NUM_STUFF\n            pan_new[pan_map == idx] = cls_new_id * DIVISOR_NEW + inst_id\n        else:\n            pan_new[pan_map == idx] = vip2hb_stuff[idx]\n    assert -1. 
not in np.unique(pan_new)\n    return pan_new\n\n\nclass SeqObj:\n    # This divisor is orthogonal with panoptic class-instance divisor.\n    DIVISOR = 1000000\n\n    def __init__(self, the_dict: Dict):\n        self.dict = the_dict\n        assert 'seq_id' in self.dict and 'img_id' in self.dict\n\n    def __hash__(self):\n        return self.dict['seq_id'] * self.DIVISOR + self.dict['img_id']\n\n    def __eq__(self, other):\n        return self.dict['seq_id'] == other.dict['seq_id'] and self.dict['img_id'] == other.dict['img_id']\n\n    def __getitem__(self, attr):\n        return self.dict[attr]\n\n\n@DATASETS.register_module()\nclass VIPSegDVPSDataset:\n    CLASSES = (\n        'dummy',\n    )\n\n    def __init__(self,\n                 pipeline=None,\n                 data_root=None,\n                 test_mode=False,\n                 split='train',\n                 ref_seq_index: List[int] = None,\n                 is_instance_only: bool = True,\n                 ):\n        logger = get_root_logger()\n\n        assert data_root is not None\n        data_root = os.path.expanduser(data_root)\n        img_root = os.path.join(data_root, 'images')\n        seg_root = os.path.join(data_root, 'panomasks')\n        assert os.path.exists(img_root)\n        assert os.path.exists(seg_root)\n\n        # read split file\n        split_file = os.path.join(data_root, split + '.txt')\n        video_folders = mmcv.list_from_file(split_file, prefix=img_root + '/')\n        ann_folders = mmcv.list_from_file(split_file, prefix=seg_root + '/')\n        logger.info(\"VIPSegDVPSDataset: There are {} videos in the {} split.\".format(len(video_folders), split))\n\n        # 58 thing classes and 66 stuff classes, 124 classes in total\n        self.num_thing_classes = 58\n        self.num_stuff_classes = 66\n        assert len(CLASSES_THING) == self.num_thing_classes\n        assert len(CLASSES_STUFF) == self.num_stuff_classes\n        assert len(CLASSES) == self.num_thing_classes + 
self.num_stuff_classes\n        self.thing_before_stuff = False\n\n        # ref_seq_index=None means no reference images are sampled\n        if ref_seq_index is None:\n            ref_seq_index = []\n\n        images = []\n        # remember that both img_id and seq_id start from 0\n        _tmp_seq_id = -1\n        for vid_folder, ann_folder in zip(video_folders, ann_folders):\n            assert os.path.basename(vid_folder) == os.path.basename(ann_folder)\n            _tmp_seq_id += 1\n            _tmp_img_id = -1\n            imgs_cur = sorted(str(x) for x in mmcv.scandir(vid_folder, recursive=False, suffix='.jpg'))\n            pans_cur = sorted(str(x) for x in mmcv.scandir(ann_folder, recursive=False, suffix='.png'))\n            for img_cur, pan_cur in zip(imgs_cur, pans_cur):\n                assert img_cur.split('.')[0] == pan_cur.split('.')[0]\n                _tmp_img_id += 1\n                seq_id = _tmp_seq_id\n                img_id = _tmp_img_id\n                item_full = os.path.join(vid_folder, img_cur)\n                inst_map = os.path.join(ann_folder, pan_cur)\n                images.append(SeqObj({\n                    'seq_id': int(seq_id),\n                    'img_id': int(img_id),\n                    'img': item_full,\n                    'ann': inst_map,\n                    'no_obj_class': 255\n                }))\n                assert os.path.exists(images[-1]['img'])\n                assert os.path.exists(images[-1]['ann'])\n\n        # Warning from Haobo: the following code is dangerous\n        # because it relies on a consistent seed among different\n        # processes. 
Please contact me before using it.\n        reference_images = {hash(image): image for image in images}\n        sequences = []\n        for img_cur in images:\n            is_seq = True\n            seq_now = [img_cur.dict]\n            if ref_seq_index:\n                for index in random.choices(ref_seq_index, k=1):\n                    query_obj = SeqObj({\n                        'seq_id': img_cur.dict['seq_id'],\n                        'img_id': img_cur.dict['img_id'] + index\n                    })\n                    if hash(query_obj) in reference_images:\n                        seq_now.append(reference_images[hash(query_obj)].dict)\n                    else:\n                        is_seq = False\n                        break\n            if is_seq:\n                sequences.append(seq_now)\n\n        self.sequences = sequences\n        self.ref_seq_index = ref_seq_index\n        logger.info(\"VIPSegDVPSDataset: There are {} clips in the {} split.\".format(\n            len(self.sequences), split))\n\n        # mmdet\n        self.pipeline = Compose(pipeline)\n        self.test_mode = test_mode\n\n        # misc\n        self.flag = self._set_groups()\n        self.is_instance_only = is_instance_only\n\n        # For evaluation\n        self.max_ins = 1000\n        self.no_obj_id = 255\n\n    def pre_pipelines(self, results):\n        for _results in results:\n            _results['img_info'] = []\n            _results['thing_lower'] = 0 if self.thing_before_stuff else self.num_stuff_classes\n            _results['thing_upper'] = self.num_thing_classes \\\n                if self.thing_before_stuff else self.num_stuff_classes + self.num_thing_classes\n            _results['is_instance_only'] = self.is_instance_only\n            _results['ori_filename'] = os.path.basename(_results['img'])\n            _results['filename'] = _results['img']\n            _results['pre_hook'] = vip2hb\n\n    def prepare_train_img(self, idx):\n        
\"\"\"Get training data and annotations after pipeline.\n\n        Args:\n            idx (int): Index of data.\n\n        Returns:\n            dict: Training data and annotation after pipeline with new keys \\\n                introduced by pipeline.\n        \"\"\"\n        results = copy.deepcopy(self.sequences[idx])\n        self.pre_pipelines(results)\n        return self.pipeline(results)\n\n    def prepare_test_img(self, idx):\n        results = copy.deepcopy(self.sequences[idx])\n        self.pre_pipelines(results)\n        # At test time, single-image inference does not require a reference sequence\n        if not self.ref_seq_index:\n            results = results[0]\n        return self.pipeline(results)\n\n    def _rand_another(self, idx):\n        \"\"\"Get another random index from the same group as the given index.\"\"\"\n        pool = np.where(self.flag == self.flag[idx])[0]\n        return np.random.choice(pool)\n\n    # Copied and modified from mmdet\n    def __getitem__(self, idx):\n        \"\"\"Get training/test data after pipeline.\n\n        Args:\n            idx (int): Index of data.\n\n        Returns:\n            dict: Training/test data (with annotation if `test_mode` is \\\n                False).\n        \"\"\"\n\n        if self.test_mode:\n            return self.prepare_test_img(idx)\n        else:\n            while True:\n                cur_data = self.prepare_train_img(idx)\n                if cur_data is None:\n                    idx = self._rand_another(idx)\n                    continue\n                return cur_data\n\n    def __len__(self):\n        \"\"\"Total number of samples of data.\"\"\"\n        return len(self.sequences)\n\n    def _set_groups(self):\n        return np.zeros((len(self)), dtype=np.int64)\n\n    # The evaluate func\n    def evaluate(\n            self,\n            results,\n            **kwargs\n    ):\n        raise NotImplementedError\n\n\nif __name__ == '__main__':\n    import 
dataset.dvps_pipelines.loading\n    import dataset.dvps_pipelines.transforms\n    import dataset.pipelines.transforms\n    import dataset.pipelines.formatting\n    import dataset.dvps_pipelines.tricks\n\n    img_norm_cfg = dict(\n        mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=False)\n\n    test_pipeline = [\n        dict(type='LoadMultiImagesDirect'),\n        dict(type='SeqPadWithDepth', size_divisor=32),\n        dict(type='SeqNormalize', **img_norm_cfg),\n        dict(\n            type='VideoCollect',\n            keys=['img']),\n        dict(type='ConcatVideoReferences'),\n        dict(type='SeqDefaultFormatBundle', ref_prefix='ref'),\n    ]\n\n    _auto_aug_policies = [\n        [\n            dict(type='ColorTransform', prob=0.5, level=3),\n            dict(type='EqualizeTransform', prob=0.5),\n            dict(type='BrightnessTransform', prob=0.5, level=3),\n            dict(type='ContrastTransform', prob=0.5, level=3),\n        ],\n        [\n            dict(type='EqualizeTransform', prob=0),\n        ]\n    ]\n\n    data = VIPSegDVPSDataset(\n        pipeline=[\n            dict(type='LoadMultiImagesDirect'),\n            dict(type='LoadMultiAnnotationsDirect', with_depth=False, vipseg=True),\n            dict(type='SeqAutoAug', policies=_auto_aug_policies),\n            dict(type='SeqResizeWithDepth', img_scale=(720, 100000), ratio_range=[1., 2.], keep_ratio=True),\n            dict(type='SeqFlipWithDepth', flip_ratio=0.5),\n            dict(type='SeqRandomCropWithDepth', crop_size=(736, 736), share_params=True),\n            dict(type='SeqPadWithDepth', size_divisor=32),\n            dict(type='SeqNormalize', **img_norm_cfg),\n            dict(\n                type='VideoCollect',\n                keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks', 'gt_semantic_seg']),\n            dict(type='ConcatVideoReferences'),\n            dict(type='SeqDefaultFormatBundle', ref_prefix='ref'),\n        ],\n        
data_root=\"data/VIPSeg\",\n        test_mode=False,\n        split='train',\n        ref_seq_index=[-1, 1],\n        is_instance_only=False,\n    )\n    np.set_string_function(lambda x: '<{} ; {}>'.format(x.shape, x.dtype))\n    torch.set_printoptions(profile='short')\n    for item in data:\n        print(item)\n"
  },
  {
    "path": "knet/__init__.py",
    "content": ""
  },
  {
    "path": "knet/cross_entropy_loss.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom mmdet.models.builder import LOSSES\nfrom mmdet.models.losses.utils import weight_reduce_loss\n\n\ndef cross_entropy(pred,\n                  label,\n                  weight=None,\n                  reduction='mean',\n                  avg_factor=None,\n                  class_weight=None,\n                  ignore_index=-100):\n    \"\"\"Calculate the CrossEntropy loss.\n    Args:\n        pred (torch.Tensor): The prediction with shape (N, C), C is the number\n            of classes.\n        label (torch.Tensor): The learning label of the prediction.\n        weight (torch.Tensor, optional): Sample-wise loss weight.\n        reduction (str, optional): The method used to reduce the loss.\n            Options are \"none\", \"mean\" and \"sum\".\n        avg_factor (int, optional): Average factor that is used to average\n            the loss. Defaults to None.\n        class_weight (list[float], optional): The weight for each class.\n        ignore_index (int, optional): The label index to be ignored.\n            Defaults to -100.\n    Returns:\n        torch.Tensor: The calculated loss\n    \"\"\"\n    # element-wise losses\n    loss = F.cross_entropy(\n        pred,\n        label,\n        weight=class_weight,\n        reduction='none',\n        ignore_index=ignore_index)\n\n    # apply weights and do the reduction\n    if weight is not None:\n        weight = weight.float()\n    loss = weight_reduce_loss(\n        loss, weight=weight, reduction=reduction, avg_factor=avg_factor)\n\n    return loss\n\n\ndef _expand_onehot_labels(labels, label_weights, label_channels):\n    bin_labels = labels.new_full((labels.size(0), label_channels), 0)\n    inds = torch.nonzero(\n        (labels >= 0) & (labels < label_channels), as_tuple=False).squeeze()\n    if inds.numel() > 0:\n        bin_labels[inds, labels[inds]] = 1\n\n    if label_weights is None:\n        bin_label_weights = None\n    
else:\n        bin_label_weights = label_weights.view(-1, 1).expand(\n            label_weights.size(0), label_channels)\n\n    return bin_labels, 
bin_label_weights\n\n\ndef binary_cross_entropy(pred,\n                         label,\n                         weight=None,\n                         reduction='mean',\n                         avg_factor=None,\n                         class_weight=None):\n    \"\"\"Calculate the binary CrossEntropy loss.\n    Args:\n        pred (torch.Tensor): The prediction with shape (N, 1).\n        label (torch.Tensor): The learning label of the prediction.\n        weight (torch.Tensor, optional): Sample-wise loss weight.\n        reduction (str, optional): The method used to reduce the loss.\n            Options are \"none\", \"mean\" and \"sum\".\n        avg_factor (int, optional): Average factor that is used to average\n            the loss. Defaults to None.\n        class_weight (list[float], optional): The weight for each class.\n    Returns:\n        torch.Tensor: The calculated loss\n    \"\"\"\n    if pred.dim() != label.dim():\n        label, weight = _expand_onehot_labels(label, weight, pred.size(-1))\n\n    # weighted element-wise losses\n    if weight is not None:\n        weight = weight.float()\n    loss = F.binary_cross_entropy_with_logits(\n        pred, label.float(), pos_weight=class_weight, reduction='none')\n    # do the reduction for the weighted loss\n    loss = weight_reduce_loss(\n        loss, weight, reduction=reduction, avg_factor=avg_factor)\n\n    return loss\n\n\ndef mask_cross_entropy(pred,\n                       target,\n                       label,\n                       reduction='mean',\n                       avg_factor=None,\n                       class_weight=None):\n    \"\"\"Calculate the CrossEntropy loss for masks.\n    Args:\n        pred (torch.Tensor): The prediction with shape (N, C, *), C is the\n            number of classes. 
The trailing * indicates arbitrary shape.\n        target (torch.Tensor): The learning label of the prediction.\n        label (torch.Tensor): ``label`` indicates the class label of the mask\n            corresponding object. This will be used to select the mask of\n            the class which the object belongs to when the mask prediction\n            is not class-agnostic.\n        reduction (str, optional): The method used to reduce the loss.\n            Options are \"none\", \"mean\" and \"sum\".\n        avg_factor (int, optional): Average factor that is used to average\n            the loss. Defaults to None.\n        class_weight (list[float], optional): The weight for each class.\n    Returns:\n        torch.Tensor: The calculated loss\n    Example:\n        >>> N, C = 3, 11\n        >>> H, W = 2, 2\n        >>> pred = torch.randn(N, C, H, W) * 1000\n        >>> target = torch.rand(N, H, W)\n        >>> label = torch.randint(0, C, size=(N,))\n        >>> reduction = 'mean'\n        >>> avg_factor = None\n        >>> class_weights = None\n        >>> loss = mask_cross_entropy(pred, target, label, reduction,\n        ...                           avg_factor, class_weights)\n        >>> assert loss.shape == (1,)\n    \"\"\"\n    # TODO: handle these two reserved arguments\n    assert reduction == 'mean' and avg_factor is None\n    num_rois = pred.size()[0]\n    inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)\n    pred_slice = pred[inds, label].squeeze(1)\n    return F.binary_cross_entropy_with_logits(\n        pred_slice, target, weight=class_weight, reduction='mean')[None]\n\n\n@LOSSES.register_module(force=True)\nclass CrossEntropyLoss(nn.Module):\n\n    def __init__(self,\n                 use_sigmoid=False,\n                 use_mask=False,\n                 reduction='mean',\n                 class_weight=None,\n                 loss_weight=1.0):\n        \"\"\"CrossEntropyLoss.\n        Args:\n            use_sigmoid (bool, 
optional): Whether the prediction uses sigmoid\n                or softmax. Defaults to False.\n            use_mask (bool, optional): Whether to use mask cross entropy loss.\n                Defaults to False.\n            reduction (str, optional): The method used to reduce the loss.\n                Options are \"none\", \"mean\" and \"sum\". Defaults to 'mean'.\n            class_weight (list[float], optional): Weight of each class.\n                Defaults to None.\n            loss_weight (float, optional): Weight of the loss. Defaults to 1.0.\n        \"\"\"\n        super(CrossEntropyLoss, self).__init__()\n        assert (use_sigmoid is False) or (use_mask is False)\n        self.use_sigmoid = use_sigmoid\n        self.use_mask = use_mask\n        self.reduction = reduction\n        self.loss_weight = loss_weight\n        self.class_weight = class_weight\n\n        if self.use_sigmoid:\n            self.cls_criterion = binary_cross_entropy\n        elif self.use_mask:\n            self.cls_criterion = mask_cross_entropy\n        else:\n            self.cls_criterion = cross_entropy\n\n    def forward(self,\n                cls_score,\n                label,\n                weight=None,\n                avg_factor=None,\n                reduction_override=None,\n                **kwargs):\n        \"\"\"Forward function.\n        Args:\n            cls_score (torch.Tensor): The prediction.\n            label (torch.Tensor): The learning label of the prediction.\n            weight (torch.Tensor, optional): Sample-wise loss weight.\n            avg_factor (int, optional): Average factor that is used to average\n                the loss. 
Defaults to None.\n            reduction_override (str, optional): The reduction method used to\n                override the original reduction method of the loss.\n                Options are \"none\", \"mean\" and \"sum\".\n        Returns:\n            torch.Tensor: The calculated loss\n        \"\"\"\n        assert reduction_override in (None, 'none', 'mean', 'sum')\n        reduction = (\n            reduction_override if reduction_override else self.reduction)\n        if self.class_weight is not None:\n            class_weight = cls_score.new_tensor(\n                self.class_weight, device=cls_score.device)\n        else:\n            class_weight = None\n        loss_cls = self.loss_weight * self.cls_criterion(\n            cls_score,\n            label,\n            weight,\n            class_weight=class_weight,\n            reduction=reduction,\n            avg_factor=avg_factor,\n            **kwargs)\n        return loss_cls\n"
  },
  {
    "path": "knet/det/dice_loss.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom mmdet.models.builder import LOSSES, build_loss\nfrom mmdet.models.losses.utils import weighted_loss\n\n\n@weighted_loss\ndef dice_loss(input, target, eps=1e-3, numerator_eps=0):\n    input = input.reshape(input.size()[0], -1)\n    target = target.reshape(target.size()[0], -1).float()\n\n    a = torch.sum(input * target, 1)\n    b = torch.sum(input * input, 1) + eps\n    c = torch.sum(target * target, 1) + eps\n    d = (2 * a + numerator_eps) / (b + c)\n    return 1 - d\n\n#\n# @LOSSES.register_module()\n# class DiceLoss(nn.Module):\n#\n#     def __init__(self,\n#                  eps=1e-3,\n#                  numerator_eps=0.0,\n#                  use_sigmoid=True,\n#                  reduction='mean',\n#                  loss_weight=1.0):\n#         super(DiceLoss, self).__init__()\n#         self.eps = eps\n#         self.reduction = reduction\n#         self.loss_weight = loss_weight\n#         self.use_sigmoid = use_sigmoid\n#         self.numerator_eps = numerator_eps\n#\n#     def forward(self,\n#                 pred,\n#                 target,\n#                 weight=None,\n#                 avg_factor=None,\n#                 reduction_override=None,\n#                 **kwargs):\n#         if weight is not None and not torch.any(weight > 0):\n#             return (pred * weight).sum()  # 0\n#         assert reduction_override in (None, 'none', 'mean', 'sum')\n#         reduction = (\n#             reduction_override if reduction_override else self.reduction)\n#         pred = pred.sigmoid()\n#         loss = self.loss_weight * dice_loss(\n#             pred,\n#             target,\n#             weight,\n#             eps=self.eps,\n#             numerator_eps=self.numerator_eps,\n#             reduction=reduction,\n#             avg_factor=avg_factor,\n#             **kwargs)\n#         return loss\n"
  },
  {
    "path": "knet/det/kernel_head.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom mmcv.cnn import (ConvModule, bias_init_with_prob, normal_init)\nfrom mmdet.core import build_assigner, build_sampler, multi_apply, reduce_mean\nfrom mmdet.models.builder import HEADS, build_loss, build_neck\nfrom mmdet.models.losses import accuracy\nfrom mmdet.utils import get_root_logger\n\n\n@HEADS.register_module()\nclass ConvKernelHead(nn.Module):\n\n    def __init__(self,\n                 num_proposals=100,\n                 in_channels=256,\n                 out_channels=256,\n                 num_heads=8,\n                 num_cls_fcs=1,\n                 num_seg_convs=1,\n                 num_loc_convs=1,\n                 att_dropout=False,\n                 localization_fpn=None,\n                 conv_kernel_size=1,\n                 norm_cfg=dict(type='GN', num_groups=32),\n                 semantic_fpn=True,\n                 train_cfg=None,\n                 num_classes=80,\n                 xavier_init_kernel=False,\n                 kernel_init_std=0.01,\n                 use_binary=False,\n                 proposal_feats_with_obj=False,\n                 loss_mask=None,\n                 loss_seg=None,\n                 loss_cls=None,\n                 loss_dice=None,\n                 loss_rank=None,\n                 feat_downsample_stride=1,\n                 feat_refine_stride=1,\n                 feat_refine=True,\n                 with_embed=False,\n                 feat_embed_only=False,\n                 conv_normal_init=False,\n                 mask_out_stride=4,\n                 hard_target=False,\n                 num_thing_classes=80,\n                 num_stuff_classes=53,\n                 mask_assign_stride=4,\n                 ignore_label=255,\n                 thing_label_in_seg=0,\n                 cat_stuff_mask=False,\n                 **kwargs):\n        super(ConvKernelHead, self).__init__()\n        self.num_proposals = 
num_proposals\n        self.num_cls_fcs = num_cls_fcs\n        self.train_cfg = train_cfg\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.num_classes = num_classes\n        self.proposal_feats_with_obj = proposal_feats_with_obj\n        self.sampling = False\n        self.localization_fpn = build_neck(localization_fpn)\n        self.semantic_fpn = semantic_fpn\n        self.norm_cfg = norm_cfg\n        self.num_heads = num_heads\n        self.att_dropout = att_dropout\n        self.mask_out_stride = mask_out_stride\n        self.hard_target = hard_target\n        self.conv_kernel_size = conv_kernel_size\n        self.xavier_init_kernel = xavier_init_kernel\n        self.kernel_init_std = kernel_init_std\n        self.feat_downsample_stride = feat_downsample_stride\n        self.feat_refine_stride = feat_refine_stride\n        self.conv_normal_init = conv_normal_init\n        self.feat_refine = feat_refine\n        self.with_embed = with_embed\n        self.feat_embed_only = feat_embed_only\n        self.num_loc_convs = num_loc_convs\n        self.num_seg_convs = num_seg_convs\n        self.use_binary = use_binary\n        self.num_thing_classes = num_thing_classes\n        self.num_stuff_classes = num_stuff_classes\n        self.mask_assign_stride = mask_assign_stride\n        self.ignore_label = ignore_label\n        self.thing_label_in_seg = thing_label_in_seg\n        self.cat_stuff_mask = cat_stuff_mask\n\n        if loss_mask is not None:\n            self.loss_mask = build_loss(loss_mask)\n        else:\n            self.loss_mask = loss_mask\n\n        if loss_dice is not None:\n            self.loss_dice = build_loss(loss_dice)\n        else:\n            self.loss_dice = loss_dice\n\n        if loss_seg is not None:\n            self.loss_seg = build_loss(loss_seg)\n        else:\n            self.loss_seg = loss_seg\n        if loss_cls is not None:\n            self.loss_cls = build_loss(loss_cls)\n        
else:\n            self.loss_cls = loss_cls\n\n        if loss_rank is not None:\n            self.loss_rank = build_loss(loss_rank)\n        else:\n            self.loss_rank = loss_rank\n\n        if self.train_cfg:\n            self.assigner = build_assigner(self.train_cfg.assigner)\n            # use MaskPseudoSampler when sampling is False\n            if self.sampling and hasattr(self.train_cfg, 'sampler'):\n                sampler_cfg = self.train_cfg.sampler\n            else:\n                sampler_cfg = dict(type='MaskPseudoSampler')\n            self.sampler = build_sampler(sampler_cfg, context=self)\n        self._init_layers()\n\n    def _init_layers(self):\n        \"\"\"Initialize proposal kernels, the semantic conv and feature convs.\"\"\"\n        self.init_kernels = nn.Conv2d(\n            self.out_channels,\n            self.num_proposals,\n            self.conv_kernel_size,\n            padding=int(self.conv_kernel_size // 2),\n            bias=False)  # (N, C)\n\n        if self.semantic_fpn:\n            self.conv_seg = nn.Conv2d(self.out_channels, self.num_classes, 1)\n\n        if self.feat_downsample_stride > 1 and self.feat_refine:\n            self.ins_downsample = ConvModule(\n                self.in_channels,\n                self.out_channels,\n                3,\n                stride=self.feat_refine_stride,  # 2\n                padding=1,\n                norm_cfg=self.norm_cfg)\n            self.seg_downsample = ConvModule(\n                self.in_channels,\n                self.out_channels,\n                3,\n                stride=self.feat_refine_stride,  # 2\n                padding=1,\n                norm_cfg=self.norm_cfg)\n\n        self.loc_convs = nn.ModuleList()\n        for i in range(self.num_loc_convs):\n            self.loc_convs.append(\n                ConvModule(\n                    self.in_channels,\n                    self.out_channels,\n                    1,\n                    
norm_cfg=self.norm_cfg))\n\n        self.seg_convs = nn.ModuleList()\n        for i in range(self.num_seg_convs):\n            self.seg_convs.append(\n                ConvModule(\n                    self.in_channels,\n                    self.out_channels,\n                    1,\n                    norm_cfg=self.norm_cfg))\n\n    def init_weights(self):\n        self.localization_fpn.init_weights()\n\n        if self.feat_downsample_stride > 1 and self.conv_normal_init:\n            logger = get_root_logger()\n            logger.info('Initialize convs in KPN head by normal std 0.01')\n            for conv in [self.loc_convs, self.seg_convs]:\n                for m in conv.modules():\n                    if isinstance(m, nn.Conv2d):\n                        normal_init(m, std=0.01)\n\n        if self.semantic_fpn:\n            bias_seg = bias_init_with_prob(0.01)\n            if self.loss_seg.use_sigmoid:\n                normal_init(self.conv_seg, std=0.01, bias=bias_seg)\n            else:\n                normal_init(self.conv_seg, mean=0, std=0.01)\n        if self.xavier_init_kernel:\n            logger = get_root_logger()\n            logger.info('Initialize kernels by xavier uniform')\n            nn.init.xavier_uniform_(self.init_kernels.weight)\n        else:\n            logger = get_root_logger()\n            logger.info(\n                f'Initialize kernels by normal std: {self.kernel_init_std}')\n            normal_init(self.init_kernels, mean=0, std=self.kernel_init_std)\n\n    def _decode_init_proposals(self, img, img_metas):\n        num_imgs = len(img_metas)\n\n        localization_feats = self.localization_fpn(img)\n\n        ## thing branch\n        if isinstance(localization_feats, list):\n            loc_feats = localization_feats[0]\n        else:\n            loc_feats = localization_feats\n        for conv in self.loc_convs:\n            loc_feats = conv(loc_feats)\n        if self.feat_downsample_stride > 1 and self.feat_refine:\n        
    loc_feats = self.ins_downsample(loc_feats)\n\n        # init kernel prediction\n        mask_preds = self.init_kernels(loc_feats)\n\n        # stuff branch\n        if self.semantic_fpn:\n            if isinstance(localization_feats, list):\n                semantic_feats = localization_feats[1]\n            else:\n                semantic_feats = localization_feats\n            for conv in self.seg_convs:\n                semantic_feats = conv(semantic_feats)\n            if self.feat_downsample_stride > 1 and self.feat_refine:\n                semantic_feats = self.seg_downsample(semantic_feats)\n        else:\n            semantic_feats = None\n\n        if semantic_feats is not None:\n            seg_preds = self.conv_seg(semantic_feats)\n        else:\n            seg_preds = None\n\n        proposal_feats = self.init_kernels.weight.clone()\n        proposal_feats = proposal_feats[None].expand(num_imgs,\n                                                     *proposal_feats.size())\n\n        if semantic_feats is not None:\n            x_feats = semantic_feats + loc_feats\n        else:\n            x_feats = loc_feats\n\n        if self.proposal_feats_with_obj:\n            sigmoid_masks = mask_preds.sigmoid()\n            nonzero_inds = sigmoid_masks > 0.5\n            if self.use_binary:\n                sigmoid_masks = nonzero_inds.float()\n            else:\n                sigmoid_masks = nonzero_inds.float() * sigmoid_masks\n            obj_feats = torch.einsum('bnhw, bchw->bnc', sigmoid_masks, x_feats)\n\n        cls_scores = None\n\n        if self.proposal_feats_with_obj:  # important use\n            proposal_feats = proposal_feats + obj_feats.view(\n                num_imgs, self.num_proposals, self.out_channels, 1, 1)\n\n        if self.cat_stuff_mask and not self.training:\n            mask_preds = torch.cat(\n                [mask_preds, seg_preds[:, self.num_thing_classes:]], dim=1)\n            stuff_kernels = self.conv_seg.weight[self.\n    
                                             num_thing_classes:].clone()\n            stuff_kernels = stuff_kernels[None].expand(num_imgs,\n                                                       *stuff_kernels.size())\n            proposal_feats = torch.cat([proposal_feats, stuff_kernels], dim=1)  # (b, N_{th}+N_{st}, c, k, k)\n\n        return proposal_feats, x_feats, mask_preds, cls_scores, seg_preds\n\n    def forward_train(self,\n                      img,\n                      img_metas,\n                      gt_masks,\n                      gt_labels,\n                      gt_sem_seg=None,\n                      gt_sem_cls=None):\n        \"\"\"Forward function in training stage.\"\"\"\n        num_imgs = len(img_metas)\n        results = self._decode_init_proposals(img, img_metas)\n        (proposal_feats, x_feats, mask_preds, cls_scores, seg_preds) = results\n        if self.feat_downsample_stride > 1:\n            scaled_mask_preds = F.interpolate(\n                mask_preds,\n                scale_factor=self.feat_downsample_stride,\n                mode='bilinear',\n                align_corners=False)\n            if seg_preds is not None:\n                scaled_seg_preds = F.interpolate(\n                    seg_preds,\n                    scale_factor=self.feat_downsample_stride,\n                    mode='bilinear',\n                    align_corners=False)\n            else:\n                scaled_seg_preds = None\n        else:\n            scaled_mask_preds = mask_preds  # thing\n            scaled_seg_preds = seg_preds   # stuff\n\n        if self.hard_target:\n            gt_masks = [x.bool().float() for x in gt_masks]\n        else:\n            gt_masks = gt_masks\n\n        sampling_results = []\n        if cls_scores is None:\n            detached_cls_scores = [None] * num_imgs\n        else:\n            detached_cls_scores = cls_scores.detach()\n\n        for i in range(num_imgs):\n            assign_result = self.assigner.assign(scaled_mask_preds[i].detach(),\n                                      
           detached_cls_scores[i],\n                                                 gt_masks[i], gt_labels[i],\n                                                 img_metas[i])\n            sampling_result = self.sampler.sample(assign_result,\n                                                  scaled_mask_preds[i],\n                                                  gt_masks[i])\n            sampling_results.append(sampling_result)\n\n        mask_targets = self.get_targets(\n            sampling_results,\n            gt_masks,\n            self.train_cfg,\n            True,\n            gt_sem_seg=gt_sem_seg,\n            gt_sem_cls=gt_sem_cls)\n\n        losses = self.loss(scaled_mask_preds, cls_scores, scaled_seg_preds,\n                           proposal_feats, *mask_targets)\n\n        if self.cat_stuff_mask and self.training:\n            mask_preds = torch.cat(\n                [mask_preds, seg_preds[:, self.num_thing_classes:]], dim=1)\n            stuff_kernels = self.conv_seg.weight[self.\n                                                 num_thing_classes:].clone()\n            stuff_kernels = stuff_kernels[None].expand(num_imgs,\n                                                       *stuff_kernels.size())\n            proposal_feats = torch.cat([proposal_feats, stuff_kernels], dim=1)\n\n        return losses, proposal_feats, x_feats, mask_preds, cls_scores\n\n    def loss(self,\n             mask_pred,\n             cls_scores,\n             seg_preds,\n             proposal_feats,\n             labels,\n             label_weights,\n             mask_targets,\n             mask_weights,\n             seg_targets,\n             reduction_override=None,\n             **kwargs):\n        losses = dict()\n        bg_class_ind = self.num_classes\n        # note: in Sparse R-CNN num_gt == num_pos\n        pos_inds = (labels >= 0) & (labels < bg_class_ind)\n        num_preds = mask_pred.shape[0] * mask_pred.shape[1]\n\n        if cls_scores is not None:\n            num_pos = pos_inds.sum().float()\n            avg_factor = reduce_mean(num_pos)\n            assert mask_pred.shape[0] == cls_scores.shape[0]\n            assert mask_pred.shape[1] == cls_scores.shape[1]\n            losses['loss_rpn_cls'] = self.loss_cls(\n                cls_scores.view(num_preds, -1),\n                labels,\n                label_weights,\n                avg_factor=avg_factor,\n                reduction_override=reduction_override)\n            losses['rpn_pos_acc'] = accuracy(\n                cls_scores.view(num_preds, -1)[pos_inds], labels[pos_inds])\n\n        bool_pos_inds = pos_inds.type(torch.bool)\n        # 0~self.num_classes-1 are FG, self.num_classes is BG\n        # mask and dice losses are computed on FG (positive) proposals only.\n        H, W = mask_pred.shape[-2:]\n        if pos_inds.any():\n            pos_mask_pred = mask_pred.reshape(num_preds, H, W)[bool_pos_inds]\n            pos_mask_targets = mask_targets[bool_pos_inds]\n            losses['loss_rpn_mask'] = self.loss_mask(pos_mask_pred,\n                                                     pos_mask_targets)\n            losses['loss_rpn_dice'] = self.loss_dice(pos_mask_pred,\n                                                     pos_mask_targets)\n\n            if self.loss_rank is not None:\n                batch_size = mask_pred.size(0)\n                rank_target = mask_targets.new_full((batch_size, H, W),\n                                                    self.ignore_label,\n                                                    dtype=torch.long)\n                rank_inds = pos_inds.view(batch_size,\n                                          -1).nonzero(as_tuple=False)\n                batch_mask_targets = mask_targets.view(batch_size, -1, H,\n                                                       W).bool()\n                for i in range(batch_size):\n                    curr_inds = (rank_inds[:, 0] == i)\n                    curr_rank = rank_inds[:, 1][curr_inds]\n                    for j in curr_rank:\n                        rank_target[i][batch_mask_targets[i][j]] = j\n                losses['loss_rpn_rank'] = self.loss_rank(\n                    mask_pred, rank_target, ignore_index=self.ignore_label)\n\n        else:\n            losses['loss_rpn_mask'] = mask_pred.sum() * 0\n            losses['loss_rpn_dice'] = mask_pred.sum() * 0\n            if self.loss_rank is not None:\n                losses['loss_rpn_rank'] = mask_pred.sum() * 0\n\n        if seg_preds is not None:\n            # focal loss\n            if self.loss_seg.use_sigmoid:\n                cls_channel = seg_preds.shape[1]\n                flatten_seg = seg_preds.view(\n                    -1, cls_channel,\n                    H * W).permute(0, 2, 1).reshape(-1, cls_channel)\n                flatten_seg_target = seg_targets.view(-1)\n                num_dense_pos = (flatten_seg_target >= 0) & (\n                    flatten_seg_target < bg_class_ind)\n                num_dense_pos = num_dense_pos.sum().float().clamp(min=1.0)\n                losses['loss_rpn_seg'] = self.loss_seg(\n                    flatten_seg,\n                    flatten_seg_target,\n                    avg_factor=num_dense_pos)\n            # ce loss\n            else:\n                cls_channel = seg_preds.shape[1]\n                flatten_seg = seg_preds.view(-1, cls_channel, H * W).permute(\n                    0, 2, 1).reshape(-1, cls_channel)\n                flatten_seg_target = seg_targets.view(-1)\n                losses['loss_rpn_seg'] = self.loss_seg(flatten_seg,\n                                                       flatten_seg_target, ignore_index=self.num_classes)\n\n        return losses\n\n    def _get_target_single(self, pos_inds, neg_inds, pos_mask, neg_mask,\n                           pos_gt_mask, pos_gt_labels, gt_sem_seg, gt_sem_cls,\n                           cfg):\n        num_pos = pos_mask.size(0)\n        num_neg = neg_mask.size(0)\n        num_samples = num_pos + 
num_neg\n        H, W = pos_mask.shape[-2:]\n        # original implementation uses new_zeros since the BG label is 0;\n        # now use new_full because BG cat_id = num_classes,\n        # FG cat_id = [0, num_classes-1]\n        labels = pos_mask.new_full((num_samples, ),\n                                   self.num_classes,\n                                   dtype=torch.long)\n        label_weights = pos_mask.new_zeros(num_samples)\n        mask_targets = pos_mask.new_zeros(num_samples, H, W)\n        mask_weights = pos_mask.new_zeros(num_samples, H, W)\n        seg_targets = pos_mask.new_full((H, W),\n                                        self.num_classes,\n                                        dtype=torch.long)\n\n        if gt_sem_cls is not None and gt_sem_seg is not None:\n            gt_sem_seg = gt_sem_seg.bool()\n            for sem_mask, sem_cls in zip(gt_sem_seg, gt_sem_cls):\n                seg_targets[sem_mask] = sem_cls.long()\n\n        if num_pos > 0:\n            labels[pos_inds] = pos_gt_labels\n            pos_weight = 1.0 if cfg.pos_weight <= 0 else cfg.pos_weight\n            label_weights[pos_inds] = pos_weight\n            mask_targets[pos_inds, ...] = pos_gt_mask\n            mask_weights[pos_inds, ...] 
= 1\n            for i in range(num_pos):\n                seg_targets[pos_gt_mask[i].bool()] = pos_gt_labels[i]\n\n        if num_neg > 0:\n            label_weights[neg_inds] = 1.0\n\n        return labels, label_weights, mask_targets, mask_weights, seg_targets\n\n    def get_targets(self,\n                    sampling_results,\n                    gt_mask,\n                    rpn_train_cfg,\n                    concat=True,\n                    gt_sem_seg=None,\n                    gt_sem_cls=None):\n        pos_inds_list = [res.pos_inds for res in sampling_results]\n        neg_inds_list = [res.neg_inds for res in sampling_results]\n        pos_mask_list = [res.pos_masks for res in sampling_results]\n        neg_mask_list = [res.neg_masks for res in sampling_results]\n        pos_gt_mask_list = [res.pos_gt_masks for res in sampling_results]\n        pos_gt_labels_list = [res.pos_gt_labels for res in sampling_results]\n        if gt_sem_seg is None:\n            # one placeholder per image, so multi_apply does not truncate the batch\n            gt_sem_seg = [None] * len(sampling_results)\n            gt_sem_cls = [None] * len(sampling_results)\n        results = multi_apply(\n            self._get_target_single,\n            pos_inds_list,\n            neg_inds_list,\n            pos_mask_list,\n            neg_mask_list,\n            pos_gt_mask_list,\n            pos_gt_labels_list,\n            gt_sem_seg,\n            gt_sem_cls,\n            cfg=rpn_train_cfg)\n        (labels, label_weights, mask_targets, mask_weights,\n         seg_targets) = results\n        if concat:\n            labels = torch.cat(labels, 0)\n            label_weights = torch.cat(label_weights, 0)\n            mask_targets = torch.cat(mask_targets, 0)\n            mask_weights = torch.cat(mask_weights, 0)\n            seg_targets = torch.stack(seg_targets, 0)\n        return labels, label_weights, mask_targets, mask_weights, seg_targets\n\n    def simple_test_rpn(self, img, img_metas):\n        \"\"\"Forward function in testing stage.\"\"\"\n        return self._decode_init_proposals(img, img_metas)\n\n    def forward_dummy(self, img, img_metas):\n        \"\"\"Dummy forward function.\n\n        Used in flops calculation.\n        \"\"\"\n        return self._decode_init_proposals(img, img_metas)\n"
  },
  {
    "path": "knet/det/kernel_iter_head.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom mmdet.core import build_assigner, build_sampler\nfrom mmdet.models.builder import HEADS, build_head\nfrom mmdet.models.roi_heads import BaseRoIHead\n\nfrom .mask_pseudo_sampler import MaskPseudoSampler\n\n\n@HEADS.register_module()\nclass KernelIterHead(BaseRoIHead):\n\n    def __init__(self,\n                 num_stages=6,\n                 recursive=False,\n                 assign_stages=5,\n                 stage_loss_weights=(1, 1, 1, 1, 1, 1),\n                 proposal_feature_channel=256,\n                 merge_cls_scores=False,\n                 do_panoptic=False,\n                 post_assign=False,\n                 hard_target=False,\n                 merge_joint=False,\n                 num_proposals=100,\n                 num_thing_classes=80,\n                 num_stuff_classes=53,\n                 mask_assign_stride=4,\n                 ignore_label=255,\n                 thing_label_in_seg=0,\n                 mask_head=dict(\n                     type='KernelUpdateHead',\n                     num_classes=80,\n                     num_fcs=2,\n                     num_heads=8,\n                     num_cls_fcs=1,\n                     num_reg_fcs=3,\n                     feedforward_channels=2048,\n                     hidden_channels=256,\n                     dropout=0.0,\n                     roi_feat_size=7,\n                     ffn_act_cfg=dict(type='ReLU', inplace=True)),\n                 mask_out_stride=4,\n                 train_cfg=None,\n                 test_cfg=None,\n                 **kwargs):\n        assert mask_head is not None\n        assert len(stage_loss_weights) == num_stages\n        self.num_stages = num_stages\n        self.stage_loss_weights = stage_loss_weights\n        self.proposal_feature_channel = proposal_feature_channel\n        self.merge_cls_scores = merge_cls_scores\n        self.recursive = recursive\n        
self.post_assign = post_assign\n        self.mask_out_stride = mask_out_stride\n        self.hard_target = hard_target\n        self.assign_stages = assign_stages\n        self.do_panoptic = do_panoptic\n        self.merge_joint = merge_joint\n        self.num_thing_classes = num_thing_classes\n        self.num_stuff_classes = num_stuff_classes\n        self.num_classes = self.num_thing_classes + self.num_stuff_classes\n        self.mask_assign_stride = mask_assign_stride\n        self.thing_label_in_seg = thing_label_in_seg\n        self.num_proposals = num_proposals\n        self.ignore_label = ignore_label\n        super(KernelIterHead, self).__init__(\n            mask_head=mask_head, train_cfg=train_cfg, test_cfg=test_cfg, **kwargs)\n        # train_cfg is None when running test.py\n        if train_cfg is not None:\n            for stage in range(num_stages):\n                assert isinstance(\n                    self.mask_sampler[stage], MaskPseudoSampler), \\\n                    'Sparse Mask only supports `MaskPseudoSampler`'\n\n    def init_bbox_head(self, mask_roi_extractor, mask_head):\n        \"\"\"Initialize box head and box roi extractor.\n\n        Args:\n            mask_roi_extractor (dict): Config of box roi extractor.\n            mask_head (dict): Config of the box head.\n        \"\"\"\n        pass\n\n    def init_assigner_sampler(self):\n        \"\"\"Initialize assigner and sampler for each stage.\"\"\"\n        self.mask_assigner = []\n        self.mask_sampler = []\n        if self.train_cfg is not None:\n            for idx, rcnn_train_cfg in enumerate(self.train_cfg):\n                self.mask_assigner.append(\n                    build_assigner(rcnn_train_cfg.assigner))\n                self.current_stage = idx\n                self.mask_sampler.append(\n                    build_sampler(rcnn_train_cfg.sampler, context=self))\n\n    def init_weights(self):\n        for i in range(self.num_stages):\n            self.mask_head[i].init_weights()\n\n    def init_mask_head(self, mask_roi_extractor, mask_head):\n        \"\"\"Initialize mask head and mask roi extractor.\n\n        Args:\n            mask_roi_extractor (dict): Config of mask roi extractor.\n            mask_head (dict): Config of the mask head.\n        \"\"\"\n        self.mask_head = nn.ModuleList()\n        if not isinstance(mask_head, list):\n            mask_head = [mask_head for _ in range(self.num_stages)]\n        assert len(mask_head) == self.num_stages\n        for head in mask_head:\n            self.mask_head.append(build_head(head))\n        if self.recursive:\n            for i in range(self.num_stages):\n                self.mask_head[i] = self.mask_head[0]\n\n    def _mask_forward(self, stage, x, object_feats, mask_preds, img_metas):\n        mask_head = self.mask_head[stage]\n        cls_score, mask_preds, object_feats = mask_head(\n            x, object_feats, mask_preds, img_metas=img_metas)\n        if mask_head.mask_upsample_stride > 1 and (stage == self.num_stages - 1\n                                                   or self.training):\n            scaled_mask_preds = F.interpolate(\n                mask_preds,\n                scale_factor=mask_head.mask_upsample_stride,\n                align_corners=False,\n                mode='bilinear')\n        else:\n            scaled_mask_preds = mask_preds\n        mask_results = dict(\n            cls_score=cls_score,\n            mask_preds=mask_preds,\n            scaled_mask_preds=scaled_mask_preds,\n            object_feats=object_feats)\n\n        return mask_results\n\n    def forward_train(self,\n                      x,\n                      proposal_feats,\n                      mask_preds,\n                      cls_score,\n                      img_metas,\n                      gt_masks,\n                      gt_labels,\n                      gt_bboxes_ignore=None,\n                      imgs_whwh=None,\n                      
gt_bboxes=None,\n                      gt_sem_seg=None,\n                      gt_sem_cls=None):\n\n        num_imgs = len(img_metas)\n        if self.mask_head[0].mask_upsample_stride > 1:\n            prev_mask_preds = F.interpolate(\n                mask_preds.detach(),\n                scale_factor=self.mask_head[0].mask_upsample_stride,\n                mode='bilinear',\n                align_corners=False)\n        else:\n            prev_mask_preds = mask_preds.detach()\n\n        if cls_score is not None:\n            prev_cls_score = cls_score.detach()\n        else:\n            prev_cls_score = [None] * num_imgs\n\n        if self.hard_target:\n            gt_masks = [x.bool().float() for x in gt_masks]\n        else:\n            gt_masks = gt_masks\n\n        object_feats = proposal_feats\n        all_stage_loss = {}\n        all_stage_mask_results = []\n        assign_results = []\n        for stage in range(self.num_stages):\n            mask_results = self._mask_forward(stage, x, object_feats,\n                                              mask_preds, img_metas)\n            all_stage_mask_results.append(mask_results)\n            mask_preds = mask_results['mask_preds']\n            scaled_mask_preds = mask_results['scaled_mask_preds']\n            cls_score = mask_results['cls_score']\n            object_feats = mask_results['object_feats']\n\n            if self.post_assign:\n                prev_mask_preds = scaled_mask_preds.detach()\n                prev_cls_score = cls_score.detach()\n\n            sampling_results = []\n            if stage < self.assign_stages:\n                assign_results = []\n            for i in range(num_imgs):\n                if stage < self.assign_stages:\n                    mask_for_assign = prev_mask_preds[i][:self.num_proposals]\n                    if prev_cls_score[i] is not None:\n                        cls_for_assign = prev_cls_score[\n                            i][:self.num_proposals, 
:self.num_thing_classes]\n                    else:\n                        cls_for_assign = None\n                    assign_result = self.mask_assigner[stage].assign(\n                        mask_for_assign, cls_for_assign, gt_masks[i],\n                        gt_labels[i], img_metas[i])\n                    assign_results.append(assign_result)\n                sampling_result = self.mask_sampler[stage].sample(\n                    assign_results[i], scaled_mask_preds[i], gt_masks[i])\n                sampling_results.append(sampling_result)\n            mask_targets = self.mask_head[stage].get_targets(\n                sampling_results,\n                gt_masks,\n                gt_labels,\n                self.train_cfg[stage],\n                True,\n                gt_sem_seg=gt_sem_seg,\n                gt_sem_cls=gt_sem_cls)\n\n            single_stage_loss = self.mask_head[stage].loss(\n                object_feats,\n                cls_score,\n                scaled_mask_preds,\n                *mask_targets,\n                imgs_whwh=imgs_whwh)\n            for key, value in single_stage_loss.items():\n                all_stage_loss[f's{stage}_{key}'] = value * \\\n                                    self.stage_loss_weights[stage]\n\n            if not self.post_assign:\n                prev_mask_preds = scaled_mask_preds.detach()\n                prev_cls_score = cls_score.detach()\n\n        return all_stage_loss\n\n    def simple_test(self,\n                    x,\n                    proposal_feats,\n                    mask_preds,\n                    cls_score,\n                    img_metas,\n                    imgs_whwh=None,\n                    rescale=False):\n\n        # Decode initial proposals\n        num_imgs = len(img_metas)\n        # num_proposals = proposal_feats.size(1)\n\n        object_feats = proposal_feats\n        for stage in range(self.num_stages):\n            mask_results = self._mask_forward(stage, x, object_feats,\n  
                                            mask_preds, img_metas)\n            object_feats = mask_results['object_feats']\n            cls_score = mask_results['cls_score']\n            mask_preds = mask_results['mask_preds']\n            scaled_mask_preds = mask_results['scaled_mask_preds']\n\n        num_classes = self.mask_head[-1].num_classes\n        results = []\n\n        if self.mask_head[-1].loss_cls.use_sigmoid:\n            cls_score = cls_score.sigmoid()\n        else:\n            cls_score = cls_score.softmax(-1)[..., :-1]\n\n        if self.do_panoptic:\n            for img_id in range(num_imgs):\n                single_result = self.get_panoptic(cls_score[img_id],\n                                                  scaled_mask_preds[img_id],\n                                                  self.test_cfg,\n                                                  img_metas[img_id])\n                results.append(single_result)\n        else:\n            for img_id in range(num_imgs):\n                cls_score_per_img = cls_score[img_id]\n                scores_per_img, topk_indices = cls_score_per_img.flatten(\n                    0, 1).topk(\n                        self.test_cfg.max_per_img, sorted=True)\n                mask_indices = topk_indices // num_classes\n                labels_per_img = topk_indices % num_classes\n                masks_per_img = scaled_mask_preds[img_id][mask_indices]\n                single_result = self.mask_head[-1].get_seg_masks(\n                    masks_per_img, labels_per_img, scores_per_img,\n                    self.test_cfg, img_metas[img_id])\n                results.append(single_result)\n        return results\n\n    def simple_test_mask_preds(self,\n                    x,\n                    proposal_feats,\n                    mask_preds,\n                    cls_score,\n                    img_metas,\n                    imgs_whwh=None,\n                    rescale=False):\n\n        # Decode initial 
proposals\n        num_imgs = len(img_metas)\n        # num_proposals = proposal_feats.size(1)\n\n        object_feats = proposal_feats\n        for stage in range(self.num_stages):\n            mask_results = self._mask_forward(stage, x, object_feats,\n                                              mask_preds, img_metas)\n            object_feats = mask_results['object_feats']\n            cls_score = mask_results['cls_score']\n            mask_preds = mask_results['mask_preds']\n            scaled_mask_preds = mask_results['scaled_mask_preds']\n\n        if self.mask_head[-1].loss_cls.use_sigmoid:\n            cls_score = cls_score.sigmoid()\n        else:\n            cls_score = cls_score.softmax(-1)[..., :-1]\n        return object_feats, cls_score, mask_preds, scaled_mask_preds\n\n\n    def aug_test(self, features, proposal_list, img_metas, rescale=False):\n        raise NotImplementedError('SparseMask does not support `aug_test`')\n\n    def forward_dummy(self, x, proposal_boxes, proposal_feats, img_metas):\n        \"\"\"Dummy forward function used for FLOPs computation.\"\"\"\n        all_stage_mask_results = []\n        num_imgs = len(img_metas)\n        num_proposals = proposal_feats.size(1)\n        C, H, W = x.shape[-3:]\n        mask_preds = proposal_feats.bmm(x.view(num_imgs, C, -1)).view(\n            num_imgs, num_proposals, H, W)\n        object_feats = proposal_feats\n        for stage in range(self.num_stages):\n            mask_results = self._mask_forward(stage, x, object_feats,\n                                              mask_preds, img_metas)\n            all_stage_mask_results.append(mask_results)\n        return all_stage_mask_results\n\n    def get_panoptic(self, cls_scores, mask_preds, test_cfg, img_meta):\n        # resize mask predictions back\n        thing_scores = cls_scores[:self.num_proposals][:, :self.\n                                                       num_thing_classes]\n        thing_mask_preds = 
mask_preds[:self.num_proposals]\n        thing_scores, topk_indices = thing_scores.flatten(0, 1).topk(\n            self.test_cfg.max_per_img, sorted=True)\n        mask_indices = topk_indices // self.num_thing_classes\n        thing_labels = topk_indices % self.num_thing_classes\n        masks_per_img = thing_mask_preds[mask_indices]\n        thing_masks = self.mask_head[-1].rescale_masks(masks_per_img, img_meta)\n\n        if not self.merge_joint:\n            thing_masks = thing_masks > test_cfg.mask_thr\n        bbox_result, segm_result = self.mask_head[-1].segm2result(\n            thing_masks, thing_labels, thing_scores)\n\n        stuff_scores = cls_scores[\n            self.num_proposals:][:, self.num_thing_classes:].diag()\n        stuff_scores, stuff_inds = torch.sort(stuff_scores, descending=True)\n        stuff_masks = mask_preds[self.num_proposals:][stuff_inds]\n        stuff_masks = self.mask_head[-1].rescale_masks(stuff_masks, img_meta)\n\n        if not self.merge_joint:\n            stuff_masks = stuff_masks > test_cfg.mask_thr\n\n        if self.merge_joint:\n            stuff_labels = stuff_inds + self.num_thing_classes\n            panoptic_result = self.merge_stuff_thing_stuff_joint(thing_masks, thing_labels,\n                                                     thing_scores, stuff_masks,\n                                                     stuff_labels, stuff_scores,\n                                                     test_cfg.merge_stuff_thing)\n        else:\n            stuff_labels = stuff_inds + 1\n            panoptic_result = self.merge_stuff_thing(thing_masks, thing_labels,\n                                                     thing_scores, stuff_masks,\n                                                     stuff_labels, stuff_scores,\n                                                     test_cfg.merge_stuff_thing)\n        return bbox_result, segm_result, panoptic_result\n\n    def split_thing_stuff(self, mask_preds, det_labels, 
cls_scores):\n        thing_scores = cls_scores[:self.num_proposals]\n        thing_masks = mask_preds[:self.num_proposals]\n        thing_labels = det_labels[:self.num_proposals]\n\n        stuff_labels = det_labels[self.num_proposals:]\n        stuff_labels = stuff_labels - self.num_thing_classes + 1\n        stuff_masks = mask_preds[self.num_proposals:]\n        stuff_scores = cls_scores[self.num_proposals:]\n\n        results = (thing_masks, thing_labels, thing_scores, stuff_masks,\n                   stuff_labels, stuff_scores)\n        return results\n\n    def merge_stuff_thing(self,\n                          thing_masks,\n                          thing_labels,\n                          thing_scores,\n                          stuff_masks,\n                          stuff_labels,\n                          stuff_scores,\n                          merge_cfg=None):\n\n        H, W = thing_masks.shape[-2:]\n        panoptic_seg = thing_masks.new_zeros((H, W), dtype=torch.int32)\n        thing_masks = thing_masks.to(\n            dtype=torch.bool, device=panoptic_seg.device)\n        stuff_masks = stuff_masks.to(\n            dtype=torch.bool, device=panoptic_seg.device)\n\n        # sort instance outputs by scores\n        sorted_inds = torch.argsort(-thing_scores)\n        current_segment_id = 0\n        segments_info = []\n        # Add instances one-by-one, check for overlaps with existing ones\n        for inst_id in sorted_inds:\n            score = thing_scores[inst_id].item()\n            if score < merge_cfg.instance_score_thr:\n                break\n            mask = thing_masks[inst_id]  # H,W\n            mask_area = mask.sum().item()\n\n            if mask_area == 0:\n                continue\n\n            intersect = (mask > 0) & (panoptic_seg > 0)\n            intersect_area = intersect.sum().item()\n\n            if intersect_area * 1.0 / mask_area > merge_cfg.iou_thr:\n                continue\n\n            if intersect_area > 0:\n        
        mask = mask & (panoptic_seg == 0)\n\n            mask_area = mask.sum().item()\n            if mask_area == 0:\n                continue\n\n            current_segment_id += 1\n            panoptic_seg[mask.bool()] = current_segment_id\n            segments_info.append({\n                'id': current_segment_id,\n                'isthing': True,\n                'score': score,\n                'category_id': thing_labels[inst_id].item(),\n                'instance_id': inst_id.item(),\n            })\n\n        # Add semantic results to remaining empty areas\n        sorted_inds = torch.argsort(-stuff_scores)\n        sorted_stuff_labels = stuff_labels[sorted_inds]\n        # paste semantic masks following the order of scores\n        processed_label = []\n        for semantic_label in sorted_stuff_labels:\n            semantic_label = semantic_label.item()\n            if semantic_label in processed_label:\n                continue\n            processed_label.append(semantic_label)\n            sem_inds = stuff_labels == semantic_label\n            sem_masks = stuff_masks[sem_inds].sum(0).bool()\n            mask = sem_masks & (panoptic_seg == 0)\n            mask_area = mask.sum().item()\n            if mask_area < merge_cfg.stuff_max_area:\n                continue\n\n            current_segment_id += 1\n            panoptic_seg[mask] = current_segment_id\n            segments_info.append({\n                'id': current_segment_id,\n                'isthing': False,\n                'category_id': semantic_label,\n                'area': mask_area,\n            })\n        return panoptic_seg.cpu().numpy(), segments_info\n\n    def merge_stuff_thing_stuff_joint(self,\n                                      thing_masks,\n                                      thing_labels,\n                                      thing_scores,\n                                      stuff_masks,\n                                      stuff_labels,\n                         
             stuff_scores,\n                                      merge_cfg=None):\n\n        H, W = thing_masks.shape[-2:]\n        panoptic_seg = thing_masks.new_zeros((H, W), dtype=torch.int32)\n\n        total_masks = torch.cat([thing_masks, stuff_masks], dim=0)\n        total_scores = torch.cat([thing_scores, stuff_scores], dim=0)\n        total_labels = torch.cat([thing_labels, stuff_labels], dim=0)\n\n        cur_prob_masks = total_scores.view(-1, 1, 1) * total_masks\n        segments_info = []\n        cur_mask_ids = cur_prob_masks.argmax(0)\n\n        # sort instance outputs by scores\n        sorted_inds = torch.argsort(-total_scores)\n        current_segment_id = 0\n\n        for k in sorted_inds:\n            pred_class = total_labels[k].item()\n            isthing = pred_class < self.num_thing_classes\n            if isthing and total_scores[k] < merge_cfg.instance_score_thr:\n                continue\n\n            mask = cur_mask_ids == k\n            mask_area = mask.sum().item()\n            original_area = (total_masks[k] >= 0.5).sum().item()\n\n            if mask_area > 0 and original_area > 0:\n                if mask_area / original_area < merge_cfg.overlap_thr:\n                    continue\n                current_segment_id += 1\n\n                panoptic_seg[mask] = current_segment_id\n\n                if isthing:\n                    segments_info.append({\n                        'id': current_segment_id,\n                        'isthing': isthing,\n                        'score': total_scores[k].item(),\n                        'category_id': pred_class,\n                        'instance_id': k.item(),\n                    })\n                else:\n                    segments_info.append({\n                        'id': current_segment_id,\n                        'isthing': isthing,\n                        'category_id': pred_class - self.num_thing_classes + 1,\n                        'area': mask_area,\n                    
})\n\n        return panoptic_seg.cpu().numpy(), segments_info\n"
  },
  {
    "path": "knet/det/kernel_update_head.py",
    "content": "import numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom mmcv.cnn import (ConvModule, bias_init_with_prob,\n                      build_activation_layer, build_norm_layer)\nfrom mmcv.runner import force_fp32\nfrom mmdet.core import multi_apply\nfrom mmdet.models.builder import HEADS, build_loss\nfrom mmdet.models.dense_heads.atss_head import reduce_mean\nfrom mmdet.models.losses import accuracy\nfrom mmcv.cnn.bricks.transformer import FFN, MultiheadAttention, build_transformer_layer\nfrom mmdet.utils import get_root_logger\n\n\n@HEADS.register_module()\nclass KernelUpdateHead(nn.Module):\n\n    def __init__(self,\n                 num_classes=80,\n                 num_ffn_fcs=2,\n                 num_heads=8,\n                 num_cls_fcs=1,\n                 num_mask_fcs=3,\n                 feedforward_channels=2048,\n                 in_channels=256,\n                 out_channels=256,\n                 dropout=0.0,\n                 mask_thr=0.5,\n                 act_cfg=dict(type='ReLU', inplace=True),\n                 ffn_act_cfg=dict(type='ReLU', inplace=True),\n                 conv_kernel_size=3,\n                 feat_transform_cfg=None,\n                 hard_mask_thr=0.5,\n                 kernel_init=False,\n                 with_ffn=True,\n                 mask_out_stride=4,\n                 relative_coors=False,\n                 relative_coors_off=False,\n                 feat_gather_stride=1,\n                 mask_transform_stride=1,\n                 mask_upsample_stride=1,\n                 num_thing_classes=80,\n                 num_stuff_classes=53,\n                 mask_assign_stride=4,\n                 ignore_label=255,\n                 thing_label_in_seg=0,\n                 kernel_updator_cfg=dict(\n                     type='DynamicConv',\n                     in_channels=256,\n                     feat_channels=64,\n                     out_channels=256,\n                     
input_feat_shape=1,\n                     act_cfg=dict(type='ReLU', inplace=True),\n                     norm_cfg=dict(type='LN')),\n                 loss_rank=None,\n                 loss_mask=dict(\n                     type='CrossEntropyLoss', use_mask=True, loss_weight=1.0),\n                 loss_dice=dict(type='DiceLoss', loss_weight=3.0),\n                 loss_cls=dict(\n                     type='FocalLoss',\n                     use_sigmoid=True,\n                     gamma=2.0,\n                     alpha=0.25,\n                     loss_weight=2.0)):\n        super(KernelUpdateHead, self).__init__()\n        self.num_classes = num_classes\n        self.loss_cls = build_loss(loss_cls)\n        self.loss_mask = build_loss(loss_mask)\n        self.loss_dice = build_loss(loss_dice)\n        if loss_rank is not None:\n            self.loss_rank = build_loss(loss_rank)\n        else:\n            self.loss_rank = loss_rank\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.mask_thr = mask_thr\n        self.fp16_enabled = False\n        self.dropout = dropout\n\n        self.num_heads = num_heads\n        self.hard_mask_thr = hard_mask_thr\n        self.kernel_init = kernel_init\n        self.with_ffn = with_ffn\n        self.mask_out_stride = mask_out_stride\n        self.relative_coors = relative_coors\n        self.relative_coors_off = relative_coors_off\n        self.conv_kernel_size = conv_kernel_size\n        self.feat_gather_stride = feat_gather_stride\n        self.mask_transform_stride = mask_transform_stride\n        self.mask_upsample_stride = mask_upsample_stride\n\n        self.num_thing_classes = num_thing_classes\n        self.num_stuff_classes = num_stuff_classes\n        self.mask_assign_stride = mask_assign_stride\n        self.ignore_label = ignore_label\n        self.thing_label_in_seg = thing_label_in_seg\n\n        self.attention = MultiheadAttention(\n            in_channels * 
conv_kernel_size**2, num_heads, dropout)\n        self.attention_norm = build_norm_layer(\n            dict(type='LN'), in_channels * conv_kernel_size**2)[1]\n\n        self.kernel_update_conv = build_transformer_layer(kernel_updator_cfg)\n\n        if feat_transform_cfg is not None:\n            kernel_size = feat_transform_cfg.pop('kernel_size', 1)\n            self.feat_transform = ConvModule(\n                in_channels,\n                in_channels,\n                kernel_size,\n                stride=feat_gather_stride,\n                padding=int(feat_gather_stride // 2),\n                **feat_transform_cfg)\n        else:\n            self.feat_transform = None\n\n        if self.with_ffn:\n            self.ffn = FFN(\n                in_channels,\n                feedforward_channels,\n                num_ffn_fcs,\n                act_cfg=ffn_act_cfg,\n                dropout=dropout)\n            self.ffn_norm = build_norm_layer(dict(type='LN'), in_channels)[1]\n\n        self.cls_fcs = nn.ModuleList()\n        for _ in range(num_cls_fcs):\n            self.cls_fcs.append(\n                nn.Linear(in_channels, in_channels, bias=False))\n            self.cls_fcs.append(\n                build_norm_layer(dict(type='LN'), in_channels)[1])\n            self.cls_fcs.append(build_activation_layer(act_cfg))\n\n        if self.loss_cls.use_sigmoid:\n            self.fc_cls = nn.Linear(in_channels, self.num_classes)\n        else:\n            self.fc_cls = nn.Linear(in_channels, self.num_classes + 1)\n\n        self.mask_fcs = nn.ModuleList()\n        for _ in range(num_mask_fcs):\n            self.mask_fcs.append(\n                nn.Linear(in_channels, in_channels, bias=False))\n            self.mask_fcs.append(\n                build_norm_layer(dict(type='LN'), in_channels)[1])\n            self.mask_fcs.append(build_activation_layer(act_cfg))\n\n        self.fc_mask = nn.Linear(in_channels, out_channels)\n\n    def init_weights(self):\n        
\"\"\"Use xavier initialization for all weight parameter and set\n        classification head bias as a specific value when use focal loss.\"\"\"\n        for p in self.parameters():\n            if p.dim() > 1:\n                nn.init.xavier_uniform_(p)\n            else:\n                # adopt the default initialization for\n                # the weight and bias of the layer norm\n                pass\n        if self.loss_cls.use_sigmoid:\n            bias_init = bias_init_with_prob(0.01)\n            nn.init.constant_(self.fc_cls.bias, bias_init)\n        if self.kernel_init:\n            logger = get_root_logger()\n            logger.info(\n                'mask kernel in mask head is normal initialized by std 0.01')\n            nn.init.normal_(self.fc_mask.weight, mean=0, std=0.01)\n\n    def forward(self,\n                x,\n                proposal_feat,\n                mask_preds,\n                prev_cls_score=None,\n                mask_shape=None,\n                img_metas=None):\n\n        N, num_proposals = proposal_feat.shape[:2]\n        if self.feat_transform is not None:\n            x = self.feat_transform(x)\n        C, H, W = x.shape[-3:]\n\n        mask_h, mask_w = mask_preds.shape[-2:]\n        if mask_h != H or mask_w != W:\n            gather_mask = F.interpolate(\n                mask_preds, (H, W), align_corners=False, mode='bilinear')\n        else:\n            gather_mask = mask_preds\n\n        sigmoid_masks = gather_mask.sigmoid()\n        nonzero_inds = sigmoid_masks > self.hard_mask_thr\n        sigmoid_masks = nonzero_inds.float()\n\n        # einsum is faster than bmm by 30%\n        x_feat = torch.einsum('bnhw,bchw->bnc', sigmoid_masks, x)\n\n        # obj_feat in shape [B, N, C, K, K] -> [B, N, C, K*K] -> [B, N, K*K, C]\n        proposal_feat = proposal_feat.reshape(N, num_proposals,\n                                              self.in_channels,\n                                              -1).permute(0, 1, 3, 2)\n  
      obj_feat = self.kernel_update_conv(x_feat, proposal_feat)\n\n        # [B, N, K*K, C] -> [B, N, K*K*C] -> [N, B, K*K*C]\n        obj_feat = obj_feat.reshape(N, num_proposals,\n                                    -1).permute(1, 0, 2)\n        obj_feat = self.attention_norm(self.attention(obj_feat))\n        # [N, B, K*K*C] -> [B, N, K*K*C]\n        obj_feat = obj_feat.permute(1, 0, 2)\n\n        # obj_feat in shape [B, N, K*K*C] -> [B, N, K*K, C]\n        obj_feat = obj_feat.reshape(N, num_proposals, -1, self.in_channels)\n\n        # FFN\n        if self.with_ffn:\n            obj_feat = self.ffn_norm(self.ffn(obj_feat))\n\n        cls_feat = obj_feat.sum(-2)\n        mask_feat = obj_feat\n\n        for cls_layer in self.cls_fcs:\n            cls_feat = cls_layer(cls_feat)\n        for reg_layer in self.mask_fcs:\n            mask_feat = reg_layer(mask_feat)\n\n        cls_score = self.fc_cls(cls_feat).view(N, num_proposals, -1)\n        # [B, N, K*K, C] -> [B, N, C, K*K]\n        mask_feat = self.fc_mask(mask_feat).permute(0, 1, 3, 2)\n\n\n        if (self.mask_transform_stride == 2\n                and self.feat_gather_stride == 1):\n            mask_x = F.interpolate(\n                x, scale_factor=0.5, mode='bilinear', align_corners=False)\n            H, W = mask_x.shape[-2:]\n        else:\n            mask_x = x\n        # group conv is 5x faster than unfold and uses about 1/5 memory\n        # Group conv vs. unfold vs. concat batch, 2.9ms :13.5ms :3.8ms\n        # Group conv vs. unfold vs. 
concat batch, 278 : 1420 : 369\n        # fold_x = F.unfold(\n        #     mask_x,\n        #     self.conv_kernel_size,\n        #     padding=int(self.conv_kernel_size // 2))\n        # mask_feat = mask_feat.reshape(N, num_proposals, -1)\n        # new_mask_preds = torch.einsum('bnc,bcl->bnl', mask_feat, fold_x)\n        # [B, N, C, K*K] -> [B*N, C, K, K]\n        mask_feat = mask_feat.reshape(N, num_proposals, C,\n                                      self.conv_kernel_size,\n                                      self.conv_kernel_size)\n        # [B, C, H, W] -> [1, B*C, H, W]\n        new_mask_preds = []\n        for i in range(N):\n            new_mask_preds.append(\n                F.conv2d(\n                    mask_x[i:i + 1],\n                    mask_feat[i],\n                    padding=int(self.conv_kernel_size // 2)))\n\n        new_mask_preds = torch.cat(new_mask_preds, dim=0)\n        new_mask_preds = new_mask_preds.reshape(N, num_proposals, H, W)\n        if self.mask_transform_stride == 2:\n            new_mask_preds = F.interpolate(\n                new_mask_preds,\n                scale_factor=2,\n                mode='bilinear',\n                align_corners=False)\n\n        if mask_shape is not None and mask_shape[0] != H:\n            new_mask_preds = F.interpolate(\n                new_mask_preds,\n                mask_shape,\n                align_corners=False,\n                mode='bilinear')\n\n        return cls_score, new_mask_preds, obj_feat.permute(0, 1, 3, 2).reshape(\n            N, num_proposals, self.in_channels, self.conv_kernel_size,\n            self.conv_kernel_size)\n\n    @force_fp32(apply_to=('cls_score', 'mask_pred'))\n    def loss(self,\n             object_feats,\n             cls_score,\n             mask_pred,\n             labels,\n             label_weights,\n             mask_targets,\n             mask_weights,\n             imgs_whwh=None,\n             reduction_override=None,\n             **kwargs):\n\n      
  losses = dict()\n        bg_class_ind = self.num_classes\n        # note: as in Sparse R-CNN, num_gt == num_pos\n        pos_inds = (labels >= 0) & (labels < bg_class_ind)\n        num_pos = pos_inds.sum().float()\n        avg_factor = reduce_mean(num_pos).clamp_(min=1.0)\n\n        num_preds = mask_pred.shape[0] * mask_pred.shape[1]\n        assert mask_pred.shape[0] == cls_score.shape[0]\n        assert mask_pred.shape[1] == cls_score.shape[1]\n\n        if cls_score is not None:\n            if cls_score.numel() > 0:\n                losses['loss_cls'] = self.loss_cls(\n                    cls_score.view(num_preds, -1),\n                    labels,\n                    label_weights,\n                    avg_factor=avg_factor,\n                    reduction_override=reduction_override)\n                losses['pos_acc'] = accuracy(\n                    cls_score.view(num_preds, -1)[pos_inds], labels[pos_inds])\n        if mask_pred is not None:\n            bool_pos_inds = pos_inds.type(torch.bool)\n            # 0~self.num_classes-1 are FG, self.num_classes is BG\n            # mask losses are computed for FG (positive) predictions only.\n            H, W = mask_pred.shape[-2:]\n            if pos_inds.any():\n                pos_mask_pred = mask_pred.reshape(num_preds, H,\n                                                  W)[bool_pos_inds]\n                pos_mask_targets = mask_targets[bool_pos_inds]\n                losses['loss_mask'] = self.loss_mask(pos_mask_pred,\n                                                     pos_mask_targets)\n                losses['loss_dice'] = self.loss_dice(pos_mask_pred,\n                                                     pos_mask_targets)\n\n                if self.loss_rank is not None:\n                    batch_size = mask_pred.size(0)\n                    rank_target = mask_targets.new_full((batch_size, H, W),\n                                                        self.ignore_label,\n                                          
              dtype=torch.long)\n                    rank_inds = pos_inds.view(batch_size,\n                                              -1).nonzero(as_tuple=False)\n                    batch_mask_targets = mask_targets.view(\n                        batch_size, -1, H, W).bool()\n                    for i in range(batch_size):\n                        curr_inds = (rank_inds[:, 0] == i)\n                        curr_rank = rank_inds[:, 1][curr_inds]\n                        for j in curr_rank:\n                            rank_target[i][batch_mask_targets[i][j]] = j\n                    losses['loss_rank'] = self.loss_rank(\n                        mask_pred, rank_target, ignore_index=self.ignore_label)\n            else:\n                losses['loss_mask'] = mask_pred.sum() * 0\n                losses['loss_dice'] = mask_pred.sum() * 0\n                if self.loss_rank is not None:\n                    losses['loss_rank'] = mask_pred.sum() * 0\n\n        return losses\n\n    def _get_target_single(self, pos_inds, neg_inds, pos_mask, neg_mask,\n                           pos_gt_mask, pos_gt_labels, gt_sem_seg, gt_sem_cls,\n                           cfg):\n\n        num_pos = pos_mask.size(0)\n        num_neg = neg_mask.size(0)\n        num_samples = num_pos + num_neg\n        H, W = pos_mask.shape[-2:]\n        # original implementation uses new_zeros since BG are set to be 0\n        # now use empty & fill because BG cat_id = num_classes,\n        # FG cat_id = [0, num_classes-1]\n        labels = pos_mask.new_full((num_samples, ),\n                                   self.num_classes,\n                                   dtype=torch.long)\n        label_weights = pos_mask.new_zeros((num_samples, self.num_classes))\n        mask_targets = pos_mask.new_zeros(num_samples, H, W)\n        mask_weights = pos_mask.new_zeros(num_samples, H, W)\n        if num_pos > 0:\n            labels[pos_inds] = pos_gt_labels\n            pos_weight = 1.0 if cfg.pos_weight <= 0 else 
cfg.pos_weight\n            label_weights[pos_inds] = pos_weight\n            pos_mask_targets = pos_gt_mask\n            mask_targets[pos_inds, ...] = pos_mask_targets\n            mask_weights[pos_inds, ...] = 1\n\n        if num_neg > 0:\n            label_weights[neg_inds] = 1.0\n\n        if gt_sem_cls is not None and gt_sem_seg is not None:\n            sem_labels = pos_mask.new_full((self.num_stuff_classes, ),\n                                           self.num_classes,\n                                           dtype=torch.long)\n            sem_targets = pos_mask.new_zeros(self.num_stuff_classes, H, W)\n            sem_weights = pos_mask.new_zeros(self.num_stuff_classes, H, W)\n            sem_stuff_weights = torch.eye(\n                self.num_stuff_classes, device=pos_mask.device)\n            sem_thing_weights = pos_mask.new_zeros(\n                (self.num_stuff_classes, self.num_thing_classes))\n            sem_label_weights = torch.cat(\n                [sem_thing_weights, sem_stuff_weights], dim=-1)\n            if len(gt_sem_cls) > 0:\n                sem_inds = gt_sem_cls - self.num_thing_classes\n                sem_inds = sem_inds.long()\n                sem_labels[sem_inds] = gt_sem_cls.long()\n                sem_targets[sem_inds] = gt_sem_seg\n                sem_weights[sem_inds] = 1\n\n            label_weights[:, self.num_thing_classes:] = 0\n            labels = torch.cat([labels, sem_labels])\n            label_weights = torch.cat([label_weights, sem_label_weights])\n            mask_targets = torch.cat([mask_targets, sem_targets])\n            mask_weights = torch.cat([mask_weights, sem_weights])\n\n        return labels, label_weights, mask_targets, mask_weights\n\n    def get_targets(self,\n                    sampling_results,\n                    gt_mask,\n                    gt_labels,\n                    rcnn_train_cfg,\n                    concat=True,\n                    gt_sem_seg=None,\n                    
gt_sem_cls=None):\n\n        pos_inds_list = [res.pos_inds for res in sampling_results]\n        neg_inds_list = [res.neg_inds for res in sampling_results]\n        pos_mask_list = [res.pos_masks for res in sampling_results]\n        neg_mask_list = [res.neg_masks for res in sampling_results]\n        pos_gt_mask_list = [res.pos_gt_masks for res in sampling_results]\n        pos_gt_labels_list = [res.pos_gt_labels for res in sampling_results]\n        if gt_sem_seg is None:\n            # one None per image so that multi_apply covers the whole batch\n            gt_sem_seg = [None] * len(sampling_results)\n            gt_sem_cls = [None] * len(sampling_results)\n\n        labels, label_weights, mask_targets, mask_weights = multi_apply(\n            self._get_target_single,\n            pos_inds_list,\n            neg_inds_list,\n            pos_mask_list,\n            neg_mask_list,\n            pos_gt_mask_list,\n            pos_gt_labels_list,\n            gt_sem_seg,\n            gt_sem_cls,\n            cfg=rcnn_train_cfg)\n        if concat:\n            labels = torch.cat(labels, 0)\n            label_weights = torch.cat(label_weights, 0)\n            mask_targets = torch.cat(mask_targets, 0)\n            mask_weights = torch.cat(mask_weights, 0)\n        return labels, label_weights, mask_targets, mask_weights\n\n    def rescale_masks(self, masks_per_img, img_meta):\n        h, w, _ = img_meta['img_shape']\n        masks_per_img = F.interpolate(\n            masks_per_img.unsqueeze(0).sigmoid(),\n            size=img_meta['batch_input_shape'],\n            mode='bilinear',\n            align_corners=False)\n\n        masks_per_img = masks_per_img[:, :, :h, :w]\n        ori_shape = img_meta['ori_shape']\n        seg_masks = F.interpolate(\n            masks_per_img,\n            size=ori_shape[:2],\n            mode='bilinear',\n            align_corners=False).squeeze(0)\n        return seg_masks\n\n    def get_seg_masks(self, masks_per_img, labels_per_img, scores_per_img,\n                      test_cfg, img_meta):\n        # resize mask predictions back\n        
seg_masks = self.rescale_masks(masks_per_img, img_meta)\n        seg_masks = seg_masks > test_cfg.mask_thr\n        bbox_result, segm_result = self.segm2result(seg_masks, labels_per_img,\n                                                    scores_per_img)\n        return bbox_result, segm_result\n\n    def segm2result(self, mask_preds, det_labels, cls_scores):\n        num_classes = self.num_classes\n        bbox_result = None\n        segm_result = [[] for _ in range(num_classes)]\n        mask_preds = mask_preds.cpu().numpy()\n        det_labels = det_labels.cpu().numpy()\n        cls_scores = cls_scores.cpu().numpy()\n        num_ins = mask_preds.shape[0]\n        # fake bboxes\n        bboxes = np.zeros((num_ins, 5), dtype=np.float32)\n        bboxes[:, -1] = cls_scores\n        bbox_result = [bboxes[det_labels == i, :] for i in range(num_classes)]\n        for idx in range(num_ins):\n            segm_result[det_labels[idx]].append(mask_preds[idx])\n        return bbox_result, segm_result\n"
  },
  {
    "path": "knet/det/knet.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom mmdet.models.builder import DETECTORS\nfrom mmdet.models.detectors import TwoStageDetector\n\nfrom .utils import sem2ins_masks, sem2ins_masks_cityscapes, sem2ins_masks_kitti_step\n\n\n@DETECTORS.register_module()\nclass KNet(TwoStageDetector):\n\n    def __init__(self,\n                 *args,\n                 num_thing_classes=80,\n                 num_stuff_classes=53,\n                 mask_assign_stride=4,\n                 ignore_label=255,\n                 thing_label_in_seg=0,\n                 cityscapes=False,\n                 kitti_step=False,\n                 **kwargs):\n        super(KNet, self).__init__(*args, **kwargs)\n        assert self.with_rpn, 'KNet does not support external proposals'\n        self.num_thing_classes = num_thing_classes\n        self.num_stuff_classes = num_stuff_classes\n        self.mask_assign_stride = mask_assign_stride\n        self.thing_label_in_seg = thing_label_in_seg\n        self.ignore_label = ignore_label\n        self.cityscapes = cityscapes  # whether to train the cityscape panoptic segmentation\n        self.kitti_step = kitti_step\n\n    def forward_train(self,\n                      img,\n                      img_metas,\n                      gt_bboxes=None,\n                      gt_labels=None,\n                      gt_bboxes_ignore=None,\n                      gt_masks=None,\n                      proposals=None,\n                      gt_semantic_seg=None,\n                      **kwargs):\n        \"\"\"Forward function of SparseR-CNN in train stage.\n\n        Args:\n            img (Tensor): of shape (N, C, H, W) encoding input images.\n                Typically these should be mean centered and std scaled.\n            img_metas (list[dict]): list of image info dict where each dict\n                has: 'img_shape', 'scale_factor', 'flip', and may also contain\n                'filename', 'ori_shape', 'pad_shape', and 
'img_norm_cfg'.\n                For details on the values of these keys see\n                :class:`mmdet.datasets.pipelines.Collect`.\n            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with\n                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.\n            gt_labels (list[Tensor]): class indices corresponding to each box\n            gt_bboxes_ignore (None | list[Tensor]): specify which bounding\n                boxes can be ignored when computing the loss.\n            gt_masks (list[:obj:`BitmapMasks`]): Ground truth instance masks\n                for each image. Required by KNet.\n            proposals (List[Tensor], optional): External proposals. Not\n                supported by KNet; must be None.\n            gt_semantic_seg (Tensor, optional): Ground truth semantic\n                segmentation maps of shape (N, 1, H, W), zero-padded when\n                forming the batch.\n\n        Returns:\n            dict[str, Tensor]: a dictionary of loss components\n        \"\"\"\n        super(TwoStageDetector, self).forward_train(img, img_metas)\n        assert proposals is None, 'KNet does not support' \\\n                                  ' external proposals'\n        assert gt_masks is not None\n\n        # gt_masks are not padded when forming a batch and gt_semantic_seg\n        # is zero-padded; both are fixed up below\n        gt_masks_tensor = []\n        gt_sem_seg = []\n        gt_sem_cls = []\n        # batch_input_shape should be the same across images\n        pad_H, pad_W = img_metas[0]['batch_input_shape']\n        assign_H = pad_H // self.mask_assign_stride\n        assign_W = pad_W // self.mask_assign_stride\n\n        for i, gt_mask in enumerate(gt_masks):\n            mask_tensor = gt_mask.to_tensor(torch.float, gt_labels[0].device)\n            if gt_mask.width != pad_W or gt_mask.height != pad_H:\n                pad_wh = (0, pad_W - gt_mask.width, 0, pad_H - gt_mask.height)\n                mask_tensor = F.pad(mask_tensor, pad_wh, value=0)\n\n            if gt_semantic_seg is not None:\n                # gt_semantic_seg is padded with zeros when forming a batch;\n                # need to convert 
them from 0 to ignore\n                gt_semantic_seg[\n                    i, :, img_metas[i]['img_shape'][0]:, :] = self.ignore_label\n                gt_semantic_seg[\n                    i, :, :, img_metas[i]['img_shape'][1]:] = self.ignore_label\n                if self.cityscapes:\n                    sem_labels, sem_seg = sem2ins_masks_cityscapes(\n                        gt_semantic_seg[i],\n                        ignore_label=self.ignore_label,\n                        label_shift=self.num_thing_classes)\n                elif self.kitti_step:\n                    sem_labels, sem_seg = sem2ins_masks_kitti_step(\n                        gt_semantic_seg[i],\n                        ignore_label=self.ignore_label,\n                        label_shift=2,\n                        thing_label_in_seg=(11, 13))\n                else:\n                    sem_labels, sem_seg = sem2ins_masks(\n                        gt_semantic_seg[i],\n                        ignore_label=self.ignore_label,\n                        label_shift=self.num_thing_classes,\n                        thing_label_in_seg=self.thing_label_in_seg)\n\n                if sem_seg.shape[0] == 0:\n                    gt_sem_seg.append(\n                        mask_tensor.new_zeros(\n                            (mask_tensor.size(0), assign_H, assign_W)))\n                else:\n                    gt_sem_seg.append(\n                        F.interpolate(\n                            sem_seg[None], (assign_H, assign_W),\n                            mode='bilinear',\n                            align_corners=False)[0])\n                gt_sem_cls.append(sem_labels)\n            else:\n                gt_sem_seg = None\n                gt_sem_cls = None\n\n            if mask_tensor.shape[0] == 0:\n                gt_masks_tensor.append(\n                    mask_tensor.new_zeros(\n                        (mask_tensor.size(0), assign_H, assign_W)))\n            else:\n                
gt_masks_tensor.append(\n                    F.interpolate(\n                        mask_tensor[None], (assign_H, assign_W),\n                        mode='bilinear',\n                        align_corners=False)[0])\n\n        gt_masks = gt_masks_tensor\n        x = self.extract_feat(img)\n        rpn_results = self.rpn_head.forward_train(x, img_metas, gt_masks,\n                                                  gt_labels, gt_sem_seg,\n                                                  gt_sem_cls)\n        (rpn_losses, proposal_feats, x_feats, mask_preds,\n         cls_scores) = rpn_results\n\n        losses = self.roi_head.forward_train(\n            x_feats,\n            proposal_feats,\n            mask_preds,\n            cls_scores,\n            img_metas,\n            gt_masks,\n            gt_labels,\n            gt_bboxes_ignore=gt_bboxes_ignore,\n            gt_bboxes=gt_bboxes,\n            gt_sem_seg=gt_sem_seg,\n            gt_sem_cls=gt_sem_cls,\n            imgs_whwh=None)\n\n        losses.update(rpn_losses)\n        return losses\n\n    def simple_test(self, img, img_metas, rescale=False):\n        \"\"\"Test function without test time augmentation.\n\n        Args:\n            img (Tensor): Input images of shape (N, C, H, W).\n            img_metas (list[dict]): List of image information.\n            rescale (bool): Whether to rescale the results.\n                Defaults to False.\n\n        Returns:\n            list[list[np.ndarray]]: BBox results of each image and classes.\n                The outer list corresponds to each image. 
The inner list\n                corresponds to each class.\n        \"\"\"\n        x = self.extract_feat(img)\n        rpn_results = self.rpn_head.simple_test_rpn(x, img_metas)\n        (proposal_feats, x_feats, mask_preds, cls_scores,\n         seg_preds) = rpn_results\n        segm_results = self.roi_head.simple_test(\n            x_feats,\n            proposal_feats,\n            mask_preds,\n            cls_scores,\n            img_metas,\n            imgs_whwh=None,\n            rescale=rescale)\n        if self.kitti_step:\n            res = segm_results[0]\n            segm_results[0] = (*res, None, None)\n        return segm_results\n\n    def forward_dummy(self, img):\n        \"\"\"Used for computing network flops.\n\n        See `mmdetection/tools/get_flops.py`\n        \"\"\"\n        # backbone\n        x = self.extract_feat(img)\n        # rpn\n        num_imgs = len(img)\n        dummy_img_metas = [\n            dict(img_shape=(3, *img.shape[-2:])) for _ in range(num_imgs)\n        ]\n        rpn_results = self.rpn_head.simple_test_rpn(x, dummy_img_metas)\n        (proposal_feats, x_feats, mask_preds, cls_scores,\n         seg_preds) = rpn_results\n        # roi_head\n        roi_outs = self.roi_head.simple_test_mask_preds(x_feats, proposal_feats, mask_preds, cls_scores, dummy_img_metas)\n        return roi_outs\n"
  },
  {
    "path": "knet/det/mask_hungarian_assigner.py",
    "content": "import numpy as np\nimport torch\n\nfrom mmdet.core import AssignResult, BaseAssigner, reduce_mean\nfrom mmdet.core.bbox.builder import BBOX_ASSIGNERS\nfrom mmdet.core.bbox.match_costs.builder import MATCH_COST, build_match_cost\n\ntry:\n    from scipy.optimize import linear_sum_assignment\nexcept ImportError:\n    linear_sum_assignment = None\n\n\n@MATCH_COST.register_module()\nclass DiceCost(object):\n    \"\"\"DiceCost.\n\n     Args:\n         weight (int | float, optional): loss_weight\n         pred_act (bool): Whether to activate the prediction\n            before calculating cost\n\n     Examples:\n         >>> from mmdet.core.bbox.match_costs.match_cost import BBoxL1Cost\n         >>> import torch\n         >>> self = BBoxL1Cost()\n         >>> bbox_pred = torch.rand(1, 4)\n         >>> gt_bboxes= torch.FloatTensor([[0, 0, 2, 4], [1, 2, 3, 4]])\n         >>> factor = torch.tensor([10, 8, 10, 8])\n         >>> self(bbox_pred, gt_bboxes, factor)\n         tensor([[1.6172, 1.6422]])\n    \"\"\"\n\n    def __init__(self,\n                 weight=1.,\n                 pred_act=False,\n                 act_mode='sigmoid',\n                 eps=1e-3):\n        self.weight = weight\n        self.pred_act = pred_act\n        self.act_mode = act_mode\n        self.eps = eps\n\n    def dice_loss(cls, input, target, eps=1e-3):\n        input = input.reshape(input.size()[0], -1)\n        target = target.reshape(target.size()[0], -1).float()\n        # einsum saves 10x memory\n        # a = torch.sum(input[:, None] * target[None, ...], -1)\n        a = torch.einsum('nh,mh->nm', input, target)\n        b = torch.sum(input * input, 1) + eps\n        c = torch.sum(target * target, 1) + eps\n        d = (2 * a) / (b[:, None] + c[None, ...])\n        # 1 is a constance that will not affect the matching, so ommitted\n        return -d\n\n    def __call__(self, mask_preds, gt_masks):\n        \"\"\"\n        Args:\n            bbox_pred (Tensor): Predicted boxes 
with normalized coordinates\n                (cx, cy, w, h), which are all in range [0, 1]. Shape\n                [num_query, 4].\n            gt_bboxes (Tensor): Ground truth boxes with normalized\n                coordinates (x1, y1, x2, y2). Shape [num_gt, 4].\n\n        Returns:\n            torch.Tensor: bbox_cost value with weight\n        \"\"\"\n        if self.pred_act and self.act_mode == 'sigmoid':\n            mask_preds = mask_preds.sigmoid().clamp(min=0.001, max=1.0)\n        elif self.pred_act:\n            mask_preds = mask_preds.softmax(dim=0)\n        # print(\"mask pred:\", mask_preds)\n        dice_cost = self.dice_loss(mask_preds, gt_masks, self.eps)\n        return dice_cost * self.weight\n\n\n@MATCH_COST.register_module()\nclass MaskCost(object):\n    \"\"\"MaskCost.\n\n    Args:\n        weight (int | float, optional): loss_weight\n    \"\"\"\n\n    def __init__(self, weight=1., pred_act=False, act_mode='sigmoid'):\n        self.weight = weight\n        self.pred_act = pred_act\n        self.act_mode = act_mode\n\n    def __call__(self, cls_pred, target):\n        \"\"\"\n        Args:\n            cls_pred (Tensor): Predicted classification logits, shape\n                [num_query, num_class].\n            gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,).\n\n        Returns:\n            torch.Tensor: cls_cost value with weight\n        \"\"\"\n        if self.pred_act and self.act_mode == 'sigmoid':\n            cls_pred = cls_pred.sigmoid().clamp(min=0.01, max=1.0)\n        elif self.pred_act:\n            cls_pred = cls_pred.softmax(dim=0)\n        num_proposals = cls_pred.shape[0]\n        num_gts, H, W = target.shape\n        # flatten_cls_pred = cls_pred.view(num_proposals, -1)\n        # eingum is ~10 times faster than matmul\n        pos_cost = torch.einsum('nhw,mhw->nm', cls_pred, target)\n        neg_cost = torch.einsum('nhw,mhw->nm', 1 - cls_pred, 1 - target)\n        # flatten_target = target.view(num_gts, -1).t()\n   
     # pos_cost = flatten_cls_pred.matmul(flatten_target)\n        # neg_cost = (1 - flatten_cls_pred).matmul(1 - flatten_target)\n        cls_cost = -(pos_cost + neg_cost) / (H * W)\n        return cls_cost * self.weight\n\n\n@BBOX_ASSIGNERS.register_module()\nclass MaskHungarianAssigner(BaseAssigner):\n    \"\"\"Computes one-to-one matching between predictions and ground truth.\n\n    This class computes an assignment between the targets and the predictions\n    based on the costs. The cost is a weighted sum of up to four components:\n    classification cost, mask cost, dice cost and (optional) boundary cost.\n    The targets don't include the no_object, so generally there are more\n    predictions than targets. After the one-to-one matching, the unmatched\n    predictions are treated as backgrounds. Thus each query prediction will be\n    assigned with `0` or a positive integer indicating the ground truth index:\n\n    - 0: negative sample, no assigned gt\n    - positive integer: positive sample, index (1-based) of assigned gt\n\n    Args:\n        cls_cost (dict): Config of the classification cost.\n            Default dict(type='ClassificationCost', weight=1.).\n        mask_cost (dict): Config of the mask cost.\n            Default dict(type='SigmoidCost', weight=1.0).\n        dice_cost (dict): Config of the dice cost, e.g.\n            dict(type='DiceCost', weight=1.0).\n        boundary_cost (dict, optional): Config of the boundary cost.\n            Default None (not used).\n        topk (int): Number of rounds of Hungarian matching. With topk > 1,\n            each ground truth can be matched to up to `topk` 
predictions.\n            Default 1.\n    \"\"\"\n\n    def __init__(self,\n                 cls_cost=dict(type='ClassificationCost', weight=1.),\n                 mask_cost=dict(type='SigmoidCost', weight=1.0),\n                 dice_cost=dict(),\n                 boundary_cost=None,\n                 topk=1):\n        self.cls_cost = build_match_cost(cls_cost)\n        self.mask_cost = build_match_cost(mask_cost)\n        self.dice_cost = build_match_cost(dice_cost)\n        if boundary_cost is not None:\n            self.boundary_cost = build_match_cost(boundary_cost)\n        else:\n            self.boundary_cost = None\n        self.topk = topk\n\n    def assign(self,\n               bbox_pred,\n               cls_pred,\n               gt_bboxes,\n               gt_labels,\n               img_meta=None,\n               gt_bboxes_ignore=None,\n               eps=1e-7):\n        \"\"\"Computes one-to-one matching based on the weighted costs.\n\n        This method assigns each query prediction to a ground truth or\n        background. The `assigned_gt_inds` with -1 means don't care,\n        0 means negative sample, and positive number is the index (1-based)\n        of assigned gt.\n        The assignment is done in the following steps, the order matters.\n\n        1. assign every prediction to -1\n        2. compute the weighted costs\n        3. do Hungarian matching on CPU based on the costs\n        4. assign all to 0 (background) first, then for each matched pair\n           between predictions and gts, treat this prediction as foreground\n           and assign the corresponding gt index (plus 1) to it.\n\n        Args:\n            bbox_pred (Tensor): Predicted boxes with normalized coordinates\n                (cx, cy, w, h), which are all in range [0, 1]. 
Shape\n                [num_query, 4].\n            cls_pred (Tensor): Predicted classification logits, shape\n                [num_query, num_class].\n            gt_bboxes (Tensor): Ground truth boxes with unnormalized\n                coordinates (x1, y1, x2, y2). Shape [num_gt, 4].\n            gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,).\n            img_meta (dict): Meta information for current image.\n            gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are\n                labelled as `ignored`. Default None.\n            eps (int | float, optional): A value added to the denominator for\n                numerical stability. Default 1e-7.\n\n        Returns:\n            :obj:`AssignResult`: The assigned result.\n        \"\"\"\n        assert gt_bboxes_ignore is None, \\\n            'Only case when gt_bboxes_ignore is None is supported.'\n        num_gts, num_bboxes = gt_bboxes.size(0), bbox_pred.size(0)\n\n        # 1. assign -1 by default\n        assigned_gt_inds = bbox_pred.new_full((num_bboxes, ),\n                                              -1,\n                                              dtype=torch.long)\n        assigned_labels = bbox_pred.new_full((num_bboxes, ),\n                                             -1,\n                                             dtype=torch.long)\n        assigned_instance_ids = bbox_pred.new_full((num_bboxes,),\n                                             -1,\n                                             dtype=torch.long)\n        if num_gts == 0 or num_bboxes == 0:\n            # No ground truth or boxes, return empty assignment\n            if num_gts == 0:\n                # No ground truth, assign all to background\n                assigned_gt_inds[:] = 0\n            return AssignResult(\n                num_gts, assigned_gt_inds, None, labels=assigned_labels)\n\n        # 2. 
compute the weighted costs\n        # classification and bboxcost.\n        if self.cls_cost.weight != 0 and cls_pred is not None:\n            cls_cost = self.cls_cost(cls_pred, gt_labels)\n        else:\n            cls_cost = 0\n        if self.mask_cost.weight != 0:\n            reg_cost = self.mask_cost(bbox_pred, gt_bboxes)\n        else:\n            reg_cost = 0\n        if self.dice_cost.weight != 0:\n            dice_cost = self.dice_cost(bbox_pred, gt_bboxes)\n        else:\n            dice_cost = 0\n        if self.boundary_cost is not None and self.boundary_cost.weight != 0:\n            b_cost = self.boundary_cost(bbox_pred, gt_bboxes)\n        else:\n            b_cost = 0\n        cost = cls_cost + reg_cost + dice_cost + b_cost\n\n\n        # 3. do Hungarian matching on CPU using linear_sum_assignment\n        cost = cost.detach().cpu()\n        if linear_sum_assignment is None:\n            raise ImportError('Please run \"pip install scipy\" '\n                              'to install scipy first.')\n        if self.topk == 1:\n            matched_row_inds, matched_col_inds = linear_sum_assignment(cost)\n        else:\n            topk_matched_row_inds = []\n            topk_matched_col_inds = []\n            for i in range(self.topk):\n                matched_row_inds, matched_col_inds = linear_sum_assignment(\n                    cost)\n                topk_matched_row_inds.append(matched_row_inds)\n                topk_matched_col_inds.append(matched_col_inds)\n                cost[matched_row_inds] = 1e10\n            matched_row_inds = np.concatenate(topk_matched_row_inds)\n            matched_col_inds = np.concatenate(topk_matched_col_inds)\n\n        matched_row_inds = torch.from_numpy(matched_row_inds).to(\n            bbox_pred.device)\n        matched_col_inds = torch.from_numpy(matched_col_inds).to(\n            bbox_pred.device)\n\n        # 4. 
assign backgrounds and foregrounds\n        # assign all indices to backgrounds first\n        assigned_gt_inds[:] = 0\n        # assign foregrounds based on matching results\n        assigned_gt_inds[matched_row_inds] = matched_col_inds + 1\n        assigned_labels[matched_row_inds] = gt_labels[matched_col_inds]\n\n        return AssignResult(\n            num_gts, assigned_gt_inds, None, labels=assigned_labels)\n"
  },
  {
    "path": "knet/det/mask_pseudo_sampler.py",
    "content": "import torch\n\nfrom mmdet.core.bbox import BaseSampler, SamplingResult\nfrom mmdet.core.bbox.builder import BBOX_SAMPLERS\n\n\nclass MaskSamplingResult(SamplingResult):\n    \"\"\"Bbox sampling result.\n\n    Example:\n        >>> # xdoctest: +IGNORE_WANT\n        >>> from mmdet.core.bbox.samplers.sampling_result import *  # NOQA\n        >>> self = SamplingResult.random(rng=10)\n        >>> print(f'self = {self}')\n        self = <SamplingResult({\n            'neg_masks': torch.Size([12, 4]),\n            'neg_inds': tensor([ 0,  1,  2,  4,  5,  6,  7,  8,  9, 10, 11, 12]),\n            'num_gts': 4,\n            'pos_assigned_gt_inds': tensor([], dtype=torch.int64),\n            'pos_masks': torch.Size([0, 4]),\n            'pos_inds': tensor([], dtype=torch.int64),\n            'pos_is_gt': tensor([], dtype=torch.uint8)\n        })>\n    \"\"\"\n\n    def __init__(self, pos_inds, neg_inds, masks, gt_masks, assign_result,\n                 gt_flags):\n        self.pos_inds = pos_inds\n        self.neg_inds = neg_inds\n        self.pos_masks = masks[pos_inds]\n        self.neg_masks = masks[neg_inds]\n        self.pos_is_gt = gt_flags[pos_inds]\n\n        self.num_gts = gt_masks.shape[0]\n        self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1\n\n        if gt_masks.numel() == 0:\n            # hack for index error case\n            assert self.pos_assigned_gt_inds.numel() == 0\n            self.pos_gt_masks = torch.empty_like(gt_masks)\n        else:\n            self.pos_gt_masks = gt_masks[self.pos_assigned_gt_inds, :]\n\n        if assign_result.labels is not None:\n            self.pos_gt_labels = assign_result.labels[pos_inds]\n        else:\n            self.pos_gt_labels = None\n\n        if \"pids\" in assign_result._extra_properties.keys():\n            self.pos_gt_pids = assign_result._extra_properties['pids'][pos_inds]\n        else:\n            self.pos_gt_pids = None\n\n    @property\n    def masks(self):\n        
\"\"\"torch.Tensor: concatenated positive and negative boxes\"\"\"\n        return torch.cat([self.pos_masks, self.neg_masks])\n\n    def __nice__(self):\n        data = self.info.copy()\n        data['pos_masks'] = data.pop('pos_masks').shape\n        data['neg_masks'] = data.pop('neg_masks').shape\n        parts = [f\"'{k}': {v!r}\" for k, v in sorted(data.items())]\n        body = '    ' + ',\\n    '.join(parts)\n        return '{\\n' + body + '\\n}'\n\n    @property\n    def info(self):\n        \"\"\"Returns a dictionary of info about the object.\"\"\"\n        return {\n            'pos_inds': self.pos_inds,\n            'neg_inds': self.neg_inds,\n            'pos_masks': self.pos_masks,\n            'neg_masks': self.neg_masks,\n            'pos_is_gt': self.pos_is_gt,\n            'num_gts': self.num_gts,\n            'pos_assigned_gt_inds': self.pos_assigned_gt_inds,\n        }\n\n\nclass MaskSamplingResultWithScore(SamplingResult):\n    \"\"\"Bbox sampling result.\n\n    Example:\n        >>> # xdoctest: +IGNORE_WANT\n        >>> from mmdet.core.bbox.samplers.sampling_result import *  # NOQA\n        >>> self = SamplingResult.random(rng=10)\n        >>> print(f'self = {self}')\n        self = <SamplingResult({\n            'neg_masks': torch.Size([12, 4]),\n            'neg_inds': tensor([ 0,  1,  2,  4,  5,  6,  7,  8,  9, 10, 11, 12]),\n            'num_gts': 4,\n            'pos_assigned_gt_inds': tensor([], dtype=torch.int64),\n            'pos_masks': torch.Size([0, 4]),\n            'pos_inds': tensor([], dtype=torch.int64),\n            'pos_is_gt': tensor([], dtype=torch.uint8)\n        })>\n    \"\"\"\n\n    def __init__(self, pos_inds, neg_inds, masks, scores, gt_masks, assign_result,\n                 gt_flags):\n        self.pos_inds = pos_inds\n        self.neg_inds = neg_inds\n        self.pos_masks = masks[pos_inds]\n        self.neg_masks = masks[neg_inds]\n\n        self.pos_scores = scores[pos_inds]\n        self.neg_scores = 
scores[neg_inds]\n\n        self.pos_is_gt = gt_flags[pos_inds]\n\n        self.num_gts = gt_masks.shape[0]\n        self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1\n\n        if gt_masks.numel() == 0:\n            # hack for index error case\n            assert self.pos_assigned_gt_inds.numel() == 0\n            self.pos_gt_masks = torch.empty_like(gt_masks)\n        else:\n            self.pos_gt_masks = gt_masks[self.pos_assigned_gt_inds, :]\n\n        if assign_result.labels is not None:\n            self.pos_gt_labels = assign_result.labels[pos_inds]\n        else:\n            self.pos_gt_labels = None\n\n        if \"pids\" in assign_result._extra_properties.keys():\n            self.pos_gt_pids = assign_result._extra_properties['pids'][pos_inds]\n        else:\n            self.pos_gt_pids = None\n\n    @property\n    def masks(self):\n        \"\"\"torch.Tensor: concatenated positive and negative masks\"\"\"\n        return torch.cat([self.pos_masks, self.neg_masks])\n\n    def __nice__(self):\n        data = self.info.copy()\n        data['pos_masks'] = data.pop('pos_masks').shape\n        data['neg_masks'] = data.pop('neg_masks').shape\n        parts = [f\"'{k}': {v!r}\" for k, v in sorted(data.items())]\n        body = '    ' + ',\\n    '.join(parts)\n        return '{\\n' + body + '\\n}'\n\n    @property\n    def info(self):\n        \"\"\"Returns a dictionary of info about the object.\"\"\"\n        return {\n            'pos_inds': self.pos_inds,\n            'neg_inds': self.neg_inds,\n            'pos_masks': self.pos_masks,\n            'neg_masks': self.neg_masks,\n            'pos_is_gt': self.pos_is_gt,\n            'num_gts': self.num_gts,\n            'pos_assigned_gt_inds': self.pos_assigned_gt_inds,\n        }\n\n\n@BBOX_SAMPLERS.register_module()\nclass MaskPseudoSampler(BaseSampler):\n    \"\"\"A pseudo sampler that does not actually sample.\"\"\"\n\n    def __init__(self, **kwargs):\n        pass\n\n    def 
_sample_pos(self, **kwargs):\n        \"\"\"Sample positive samples.\"\"\"\n        raise NotImplementedError\n\n    def _sample_neg(self, **kwargs):\n        \"\"\"Sample negative samples.\"\"\"\n        raise NotImplementedError\n\n    def sample(self, assign_result, masks, gt_masks, **kwargs):\n        \"\"\"Directly returns the positive and negative indices of samples.\n\n        Args:\n            assign_result (:obj:`AssignResult`): Assigned results\n            masks (torch.Tensor): Predicted masks\n            gt_masks (torch.Tensor): Ground truth masks\n\n        Returns:\n            :obj:`MaskSamplingResult`: sampler results\n        \"\"\"\n        pos_inds = torch.nonzero(\n            assign_result.gt_inds > 0, as_tuple=False).squeeze(-1).unique()\n        neg_inds = torch.nonzero(\n            assign_result.gt_inds == 0, as_tuple=False).squeeze(-1).unique()\n        gt_flags = masks.new_zeros(masks.shape[0], dtype=torch.uint8)\n        sampling_result = MaskSamplingResult(pos_inds, neg_inds, masks,\n                                             gt_masks, assign_result, gt_flags)\n        return sampling_result\n\n\n@BBOX_SAMPLERS.register_module()\nclass MaskScorePseudoSampler(BaseSampler):\n    \"\"\"A pseudo sampler that does not actually sample.\"\"\"\n\n    def __init__(self, **kwargs):\n        pass\n\n    def _sample_pos(self, **kwargs):\n        \"\"\"Sample positive samples.\"\"\"\n        raise NotImplementedError\n\n    def _sample_neg(self, **kwargs):\n        \"\"\"Sample negative samples.\"\"\"\n        raise NotImplementedError\n\n    def sample(self, assign_result, masks, score, gt_masks, **kwargs):\n        \"\"\"Directly returns the positive and negative indices of samples.\n\n        Args:\n            assign_result (:obj:`AssignResult`): Assigned results\n            masks (torch.Tensor): Predicted masks\n            score (torch.Tensor): Score of each predicted mask, indexed\n                like `masks`\n            gt_masks (torch.Tensor): Ground truth masks\n\n        Returns:\n            :obj:`MaskSamplingResultWithScore`: sampler 
results\n        \"\"\"\n        pos_inds = torch.nonzero(\n            assign_result.gt_inds > 0, as_tuple=False).squeeze(-1).unique()\n        neg_inds = torch.nonzero(\n            assign_result.gt_inds == 0, as_tuple=False).squeeze(-1).unique()\n        gt_flags = masks.new_zeros(masks.shape[0], dtype=torch.uint8)\n        sampling_result = MaskSamplingResultWithScore(\n            pos_inds, neg_inds, masks, score, gt_masks, assign_result,\n            gt_flags)\n        return sampling_result\n"
  },
  {
    "path": "knet/det/msdeformattn_decoder.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom mmcv.cnn import (Conv2d, ConvModule, caffe2_xavier_init,\n                      normal_init, xavier_init)\nfrom mmdet.models.builder import NECKS\n\nfrom mmcv.cnn.bricks.transformer import (build_positional_encoding,\n                                         build_transformer_layer_sequence)\nfrom mmcv.runner import BaseModule, ModuleList\n\nfrom mmdet.core.anchor import MlvlPointGenerator\nfrom mmdet.models.utils.transformer import MultiScaleDeformableAttention\n\n\n@NECKS.register_module()\nclass MSDeformAttnPixelDecoder(BaseModule):\n    \"\"\"Pixel decoder with multi-scale deformable attention.\n\n    Args:\n        in_channels (list[int] | tuple[int]): Number of channels in the\n            input feature maps.\n        strides (list[int] | tuple[int]): Output strides of feature from\n            backbone.\n        feat_channels (int): Number of channels for feature.\n        out_channels (int): Number of channels for output.\n        num_outs (int): Number of output scales.\n        norm_cfg (:obj:`mmcv.ConfigDict` | dict): Config for normalization.\n            Defaults to dict(type='GN', num_groups=32).\n        act_cfg (:obj:`mmcv.ConfigDict` | dict): Config for activation.\n            Defaults to dict(type='ReLU').\n        encoder (:obj:`mmcv.ConfigDict` | dict): Config for transformer\n            encoder. Defaults to `DetrTransformerEncoder`.\n        positional_encoding (:obj:`mmcv.ConfigDict` | dict): Config for\n            transformer encoder position encoding. 
Defaults to\n            dict(type='SinePositionalEncoding', num_feats=128,\n            normalize=True).\n        init_cfg (:obj:`mmcv.ConfigDict` | dict): Initialization config dict.\n    \"\"\"\n\n    def __init__(self,\n                 in_channels=[256, 512, 1024, 2048],\n                 strides=[4, 8, 16, 32],\n                 feat_channels=256,\n                 out_channels=256,\n                 num_outs=3,\n                 return_one_list=True,\n                 norm_cfg=dict(type='GN', num_groups=32),\n                 act_cfg=dict(type='ReLU'),\n                 encoder=dict(\n                     type='DetrTransformerEncoder',\n                     num_layers=6,\n                     transformerlayers=dict(\n                         type='BaseTransformerLayer',\n                         attn_cfgs=dict(\n                             type='MultiScaleDeformableAttention',\n                             embed_dims=256,\n                             num_heads=8,\n                             num_levels=3,\n                             num_points=4,\n                             im2col_step=64,\n                             dropout=0.0,\n                             batch_first=False,\n                             norm_cfg=None,\n                             init_cfg=None),\n                         feedforward_channels=1024,\n                         ffn_dropout=0.0,\n                         operation_order=('self_attn', 'norm', 'ffn', 'norm')),\n                     init_cfg=None),\n                 positional_encoding=dict(\n                     type='SinePositionalEncoding',\n                     num_feats=128,\n                     normalize=True),\n                 init_cfg=None):\n        super().__init__(init_cfg=init_cfg)\n        self.strides = strides\n        self.num_input_levels = len(in_channels)\n        self.return_one_list = return_one_list\n        self.num_encoder_levels = \\\n            
encoder.transformerlayers.attn_cfgs.num_levels\n        assert self.num_encoder_levels >= 1, \\\n            'num_levels in attn_cfgs must be at least one'\n        input_conv_list = []\n        # from top to down (low to high resolution)\n        for i in range(self.num_input_levels - 1,\n                       self.num_input_levels - self.num_encoder_levels - 1,\n                       -1):\n            input_conv = ConvModule(\n                in_channels[i],\n                feat_channels,\n                kernel_size=1,\n                norm_cfg=norm_cfg,\n                act_cfg=None,\n                bias=True)\n            input_conv_list.append(input_conv)\n        self.input_convs = ModuleList(input_conv_list)\n\n        self.encoder = build_transformer_layer_sequence(encoder)\n        self.postional_encoding = build_positional_encoding(\n            positional_encoding)\n        # high resolution to low resolution\n        self.level_encoding = nn.Embedding(self.num_encoder_levels,\n                                           feat_channels)\n\n        # fpn-like structure\n        self.lateral_convs = ModuleList()\n        self.output_convs = ModuleList()\n        self.use_bias = norm_cfg is None\n        # from top to down (low to high resolution)\n        # fpn for the rest features that didn't pass in encoder\n        for i in range(self.num_input_levels - self.num_encoder_levels - 1, -1,\n                       -1):\n            lateral_conv = ConvModule(\n                in_channels[i],\n                feat_channels,\n                kernel_size=1,\n                bias=self.use_bias,\n                norm_cfg=norm_cfg,\n                act_cfg=None)\n            output_conv = ConvModule(\n                feat_channels,\n                feat_channels,\n                kernel_size=3,\n                stride=1,\n                padding=1,\n                bias=self.use_bias,\n                norm_cfg=norm_cfg,\n                act_cfg=act_cfg)\n       
     self.lateral_convs.append(lateral_conv)\n            self.output_convs.append(output_conv)\n\n        self.mask_feature = Conv2d(\n            feat_channels, out_channels, kernel_size=1, stride=1, padding=0)\n\n        self.num_outs = num_outs\n        self.point_generator = MlvlPointGenerator(strides)\n\n    def init_weights(self):\n        \"\"\"Initialize weights.\"\"\"\n        for i in range(0, self.num_encoder_levels):\n            xavier_init(\n                self.input_convs[i].conv,\n                gain=1,\n                bias=0,\n                distribution='uniform')\n\n        for i in range(0, self.num_input_levels - self.num_encoder_levels):\n            caffe2_xavier_init(self.lateral_convs[i].conv, bias=0)\n            caffe2_xavier_init(self.output_convs[i].conv, bias=0)\n\n        caffe2_xavier_init(self.mask_feature, bias=0)\n\n        normal_init(self.level_encoding, mean=0, std=1)\n        for p in self.encoder.parameters():\n            if p.dim() > 1:\n                nn.init.xavier_normal_(p)\n\n        # init_weights defined in MultiScaleDeformableAttention\n        for layer in self.encoder.layers:\n            for attn in layer.attentions:\n                if isinstance(attn, MultiScaleDeformableAttention):\n                    attn.init_weights()\n\n    def forward(self, feats):\n        \"\"\"\n        Args:\n            feats (list[Tensor]): Feature maps of each level. 
Each has\n                shape of (batch_size, c, h, w).\n\n        Returns:\n            tuple: A tuple containing the following:\n\n            - mask_feature (Tensor): shape (batch_size, c, h, w).\n            - multi_scale_features (list[Tensor]): Multi scale \\\n                    features, each in shape (batch_size, c, h, w).\n        \"\"\"\n        # generate padding mask for each level, for each image\n        batch_size = feats[0].shape[0]\n        encoder_input_list = []\n        padding_mask_list = []\n        level_positional_encoding_list = []\n        spatial_shapes = []\n        reference_points_list = []\n        for i in range(self.num_encoder_levels):\n            level_idx = self.num_input_levels - i - 1\n            feat = feats[level_idx]\n            feat_projected = self.input_convs[i](feat)\n            h, w = feat.shape[-2:]\n\n            # no padding\n            padding_mask_resized = feat.new_zeros(\n                (batch_size, ) + feat.shape[-2:], dtype=torch.bool)\n            pos_embed = self.postional_encoding(padding_mask_resized)\n            level_embed = self.level_encoding.weight[i]\n            level_pos_embed = level_embed.view(1, -1, 1, 1) + pos_embed\n            # (h_i * w_i, 2)\n            reference_points = self.point_generator.single_level_grid_priors(\n                feat.shape[-2:], level_idx, device=feat.device)\n            # normalize\n            factor = feat.new_tensor([[w, h]]) * self.strides[level_idx]\n            reference_points = reference_points / factor\n\n            # shape (batch_size, c, h_i, w_i) -> (h_i * w_i, batch_size, c)\n            feat_projected = feat_projected.flatten(2).permute(2, 0, 1)\n            level_pos_embed = level_pos_embed.flatten(2).permute(2, 0, 1)\n            padding_mask_resized = padding_mask_resized.flatten(1)\n\n            encoder_input_list.append(feat_projected)\n            padding_mask_list.append(padding_mask_resized)\n            
level_positional_encoding_list.append(level_pos_embed)\n            spatial_shapes.append(feat.shape[-2:])\n            reference_points_list.append(reference_points)\n        # shape (batch_size, total_num_query),\n        # total_num_query=sum([., h_i * w_i,.])\n        padding_masks = torch.cat(padding_mask_list, dim=1)\n        # shape (total_num_query, batch_size, c)\n        encoder_inputs = torch.cat(encoder_input_list, dim=0)\n        level_positional_encodings = torch.cat(\n            level_positional_encoding_list, dim=0)\n        device = encoder_inputs.device\n        # shape (num_encoder_levels, 2), from low\n        # resolution to high resolution\n        spatial_shapes = torch.as_tensor(\n            spatial_shapes, dtype=torch.long, device=device)\n        # shape (0, h_0*w_0, h_0*w_0+h_1*w_1, ...)\n        level_start_index = torch.cat((spatial_shapes.new_zeros(\n            (1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))\n        reference_points = torch.cat(reference_points_list, dim=0)\n        reference_points = reference_points[None, :, None].repeat(\n            batch_size, 1, self.num_encoder_levels, 1)\n        valid_radios = reference_points.new_ones(\n            (batch_size, self.num_encoder_levels, 2))\n        # shape (num_total_query, batch_size, c)\n        memory = self.encoder(\n            query=encoder_inputs,\n            key=None,\n            value=None,\n            query_pos=level_positional_encodings,\n            key_pos=None,\n            attn_masks=None,\n            key_padding_mask=None,\n            query_key_padding_mask=padding_masks,\n            spatial_shapes=spatial_shapes,\n            reference_points=reference_points,\n            level_start_index=level_start_index,\n            valid_radios=valid_radios)\n        # (num_total_query, batch_size, c) -> (batch_size, c, num_total_query)\n        memory = memory.permute(1, 2, 0)\n\n        # from low resolution to high resolution\n        num_query_per_level = 
[e[0] * e[1] for e in spatial_shapes]\n        outs = torch.split(memory, num_query_per_level, dim=-1)\n        outs = [\n            x.reshape(batch_size, -1, spatial_shapes[i][0],\n                      spatial_shapes[i][1]) for i, x in enumerate(outs)\n        ]\n\n        for i in range(self.num_input_levels - self.num_encoder_levels - 1, -1,\n                       -1):\n            x = feats[i]\n            cur_feat = self.lateral_convs[i](x)\n            y = cur_feat + F.interpolate(\n                outs[-1],\n                size=cur_feat.shape[-2:],\n                mode='bilinear',\n                align_corners=False)\n            y = self.output_convs[i](y)\n            outs.append(y)\n        multi_scale_features = outs[:self.num_outs]\n\n        mask_feature = self.mask_feature(outs[-1])\n        multi_scale_features.append(mask_feature)\n        multi_scale_features.reverse()\n        return tuple(multi_scale_features)\n"
  },
  {
    "path": "knet/det/semantic_fpn_wrapper.py",
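    "content": "import math\n\n\nimport torch\nimport torch.nn as nn\nfrom torch.nn import init\nfrom mmcv.cnn import ConvModule, normal_init\nfrom mmdet.models.builder import NECKS, BACKBONES\nfrom mmcv.cnn.bricks.transformer import build_positional_encoding\nfrom mmdet.utils import get_root_logger\nfrom mmcv.ops import DeformConv2dPack\nfrom mmcv.runner import BaseModule\nimport torch.nn.functional as F\n\n\n@NECKS.register_module()\nclass SemanticFPNWrapper(nn.Module):\n    \"\"\"\n    Implementation of Semantic FPN used in Panoptic FPN.\n\n    Args:\n        in_channels (int): Number of channels of each input feature level.\n        feat_channels (int): Number of channels of the intermediate convs.\n        out_channels (int): Number of channels of the output convs.\n        start_level (int): Index of the first input level to use.\n        end_level (int): Index of the last input level to use (inclusive).\n        cat_coors (bool, optional): Whether to concatenate normalized\n            (x, y) coordinate maps to the input at ``cat_coors_level``.\n            Defaults to False.\n        positional_encoding (dict, optional): Config of a positional\n            encoding added to the input at ``cat_coors_level``.\n            Defaults to None.\n        cat_coors_level (int, optional): Input level that receives the\n            coordinate maps and the positional encoding. Defaults to 3.\n        fuse_by_cat (bool, optional): Whether to fuse the levels by\n            channel concatenation instead of summation. Defaults to False.\n        return_list (bool, optional): Whether to wrap the output in a\n            list. Defaults to False.\n        upsample_times (int, optional): Number of 2x upsampling steps\n            applied to the ``end_level`` feature; this sets the common\n            resolution at which the levels are fused. Defaults to 3.\n        with_pred (bool, optional): Whether to apply the final 1x1 output\n            conv. Defaults to True.\n        num_aux_convs (int, optional): Number of auxiliary 1x1 output\n            convs. If > 0, a list of outputs is returned. Defaults to 0.\n        act_cfg (dict, optional): Activation config of the intermediate\n            convs. Defaults to dict(type='ReLU', inplace=True).\n        out_act_cfg (dict, optional): Activation config of the output\n            convs. Defaults to dict(type='ReLU').\n        conv_cfg (dict, optional): Config dict for convolution layers.\n            Defaults to None.\n        norm_cfg (dict, optional): Config dict for normalization layers. 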
Defaults to None.\n    \"\"\"\n\n    def __init__(self,\n                 in_channels,\n                 feat_channels,\n                 out_channels,\n                 start_level,\n                 end_level,\n                 cat_coors=False,\n                 positional_encoding=None,\n                 cat_coors_level=3,\n                 fuse_by_cat=False,\n                 return_list=False,\n                 upsample_times=3,\n                 with_pred=True,\n                 num_aux_convs=0,\n                 act_cfg=dict(type='ReLU', inplace=True),\n                 out_act_cfg=dict(type='ReLU'),\n                 conv_cfg=None,\n                 norm_cfg=None):\n        super(SemanticFPNWrapper, self).__init__()\n\n        self.in_channels = in_channels\n        self.feat_channels = feat_channels\n        self.start_level = start_level\n        self.end_level = end_level\n        assert start_level >= 0 and end_level >= start_level\n        self.out_channels = out_channels\n        self.conv_cfg = conv_cfg\n        self.norm_cfg = norm_cfg\n        self.act_cfg = act_cfg\n        self.cat_coors = cat_coors\n        self.cat_coors_level = cat_coors_level\n        self.fuse_by_cat = fuse_by_cat\n        self.return_list = return_list\n        self.upsample_times = upsample_times\n        self.with_pred = with_pred\n        if positional_encoding is not None:\n            self.positional_encoding = build_positional_encoding(\n                positional_encoding)\n        else:\n            self.positional_encoding = None\n\n        self.convs_all_levels = nn.ModuleList()\n        for i in range(self.start_level, self.end_level + 1):\n            convs_per_level = nn.Sequential()\n            if i == 0:\n                if i == self.cat_coors_level and self.cat_coors:\n                    chn = self.in_channels + 2\n                else:\n                    chn = self.in_channels\n                if upsample_times == self.end_level - i:\n                   
 one_conv = ConvModule(\n                        chn,\n                        self.feat_channels,\n                        3,\n                        padding=1,\n                        conv_cfg=self.conv_cfg,\n                        norm_cfg=self.norm_cfg,\n                        act_cfg=self.act_cfg,\n                        inplace=False)\n                    convs_per_level.add_module('conv' + str(i), one_conv)\n                else:\n                    for i in range(self.end_level - upsample_times):\n                        one_conv = ConvModule(\n                            chn,\n                            self.feat_channels,\n                            3,\n                            padding=1,\n                            stride=2,\n                            conv_cfg=self.conv_cfg,\n                            norm_cfg=self.norm_cfg,\n                            act_cfg=self.act_cfg,\n                            inplace=False)\n                        convs_per_level.add_module('conv' + str(i), one_conv)\n                self.convs_all_levels.append(convs_per_level)\n                continue\n\n            for j in range(i):\n                if j == 0:\n                    if i == self.cat_coors_level and self.cat_coors:\n                        chn = self.in_channels + 2\n                    else:\n                        chn = self.in_channels\n                    one_conv = ConvModule(\n                        chn,\n                        self.feat_channels,\n                        3,\n                        padding=1,\n                        conv_cfg=self.conv_cfg,\n                        norm_cfg=self.norm_cfg,\n                        act_cfg=self.act_cfg,\n                        inplace=False)\n                    convs_per_level.add_module('conv' + str(j), one_conv)\n                    if j < upsample_times - (self.end_level - i):\n                        one_upsample = nn.Upsample(\n                            scale_factor=2,\n     
                       mode='bilinear',\n                            align_corners=False)\n                        convs_per_level.add_module('upsample' + str(j),\n                                                   one_upsample)\n                    continue\n\n                one_conv = ConvModule(\n                    self.feat_channels,\n                    self.feat_channels,\n                    3,\n                    padding=1,\n                    conv_cfg=self.conv_cfg,\n                    norm_cfg=self.norm_cfg,\n                    act_cfg=self.act_cfg,\n                    inplace=False)\n                convs_per_level.add_module('conv' + str(j), one_conv)\n                if j < upsample_times - (self.end_level - i):\n                    one_upsample = nn.Upsample(\n                        scale_factor=2, mode='bilinear', align_corners=False)\n                    convs_per_level.add_module('upsample' + str(j),\n                                               one_upsample)\n\n            self.convs_all_levels.append(convs_per_level)\n\n        if fuse_by_cat:\n            in_channels = self.feat_channels * len(self.convs_all_levels)\n        else:\n            in_channels = self.feat_channels\n\n        if self.with_pred:\n            self.conv_pred = ConvModule(\n                in_channels,\n                self.out_channels,\n                1,\n                padding=0,\n                conv_cfg=self.conv_cfg,\n                act_cfg=out_act_cfg,\n                norm_cfg=self.norm_cfg)\n\n        self.num_aux_convs = num_aux_convs\n        self.aux_convs = nn.ModuleList()\n        for i in range(num_aux_convs):\n            self.aux_convs.append(\n                ConvModule(\n                    in_channels,\n                    self.out_channels,\n                    1,\n                    padding=0,\n                    conv_cfg=self.conv_cfg,\n                    act_cfg=out_act_cfg,\n                    norm_cfg=self.norm_cfg))\n\n    def 
init_weights(self):\n        logger = get_root_logger()\n        logger.info('Use normal initialization for semantic FPN')\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                normal_init(m, std=0.01)\n\n    def generate_coord(self, input_feat):\n        x_range = torch.linspace(\n            -1, 1, input_feat.shape[-1], device=input_feat.device)\n        y_range = torch.linspace(\n            -1, 1, input_feat.shape[-2], device=input_feat.device)\n        y, x = torch.meshgrid(y_range, x_range)\n        y = y.expand([input_feat.shape[0], 1, -1, -1])\n        x = x.expand([input_feat.shape[0], 1, -1, -1])\n        coord_feat = torch.cat([x, y], 1)\n        return coord_feat\n\n    def forward(self, inputs):\n        mlvl_feats = []\n        for i in range(self.start_level, self.end_level + 1):\n            input_p = inputs[i]\n            if i == self.cat_coors_level:\n                if self.positional_encoding is not None:\n                    ignore_mask = input_p.new_zeros(\n                        (input_p.shape[0], input_p.shape[-2],\n                         input_p.shape[-1]),\n                        dtype=torch.bool)\n                    positional_encoding = self.positional_encoding(ignore_mask)\n                    input_p = input_p + positional_encoding\n                if self.cat_coors:\n                    coord_feat = self.generate_coord(input_p)\n                    input_p = torch.cat([input_p, coord_feat], 1)\n\n            # convs_all_levels holds one entry per level starting at\n            # start_level, so it must be indexed relative to start_level\n            mlvl_feats.append(\n                self.convs_all_levels[i - self.start_level](input_p))\n\n        if self.fuse_by_cat:\n            feature_add_all_level = torch.cat(mlvl_feats, dim=1)\n        else:\n            feature_add_all_level = sum(mlvl_feats)\n\n        if self.with_pred:\n            out = self.conv_pred(feature_add_all_level)\n        else:\n            out = feature_add_all_level\n\n        if self.num_aux_convs > 0:\n            outs = [out]\n            for conv in self.aux_convs:\n                
outs.append(conv(feature_add_all_level))\n            return outs\n\n        if self.return_list:\n            return [out]\n        else:\n            return out\n\n\n@NECKS.register_module()\nclass UperNetAlignHead(BaseModule):\n\n    def __init__(self, in_channels=[256, 512, 1024, 2048], out_channels=256, feat_channels=256, align_types=\"v1\",\n                 start_level=1, end_level=3, conv3x3_type=\"conv\", positional_encoding=None, cat_coors_level=3,\n                 upsample_times=2, cat_coors=False, fuse_by_cat=False, return_list=False,\n                 num_aux_convs=1, norm_cfg=dict(type='GN', num_groups=32, requires_grad=True) ):\n        super(UperNetAlignHead, self).__init__()\n\n        if positional_encoding is not None:\n            self.positional_encoding = build_positional_encoding(\n                positional_encoding)\n        else:\n            self.positional_encoding = None\n\n        self.cat_coors_level = cat_coors_level\n        self.align_types = align_types\n\n        # the fused feature map fed to the DCN has out_channels channels\n        self.dcn = DeformConv2dPack(in_channels=out_channels, out_channels=out_channels, kernel_size=3, padding=1)\n        self.fpn_in = []\n        for fpn_inplane in in_channels[:-1]:\n            self.fpn_in.append(\n                ConvModule(fpn_inplane, out_channels, kernel_size=1, norm_cfg=dict(type='BN2d'),\n                           act_cfg=dict(type='ReLU'),\n                           inplace=False)\n            )\n        self.fpn_in = nn.ModuleList(self.fpn_in)\n\n        self.fpn_out = []\n        self.fpn_out_align = []\n        self.dsn = []\n        for i in range(len(in_channels) - 1):\n            self.fpn_out.append(\n                ConvModule(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1,\n                           norm_cfg=dict(type='BN2d')))\n\n            if conv3x3_type == 'conv':\n                if self.align_types == \"v1\":\n                    self.fpn_out_align.append(\n                        
AlignedModule(inplane=out_channels, outplane=out_channels // 2)\n                    )\n                else:\n                    self.fpn_out_align.append(\n                        AlignedModulev2PoolingAtten(inplane=out_channels, outplane=out_channels // 2)\n                    )\n\n        self.fpn_out = nn.ModuleList(self.fpn_out)\n        self.fpn_out_align = nn.ModuleList(self.fpn_out_align)\n\n    def forward(self, conv_out):\n        f = conv_out[-1]\n        fpn_feature_list = [f]\n        for i in reversed(range(len(conv_out) - 1)):\n            conv_x = conv_out[i]\n            conv_x = self.fpn_in[i](conv_x)\n            f = self.fpn_out_align[i]([conv_x, f])\n            f = conv_x + f\n            fpn_feature_list.append(self.fpn_out[i](f))\n\n        output_size = conv_out[1].size()[2:]\n        fusion_list = []\n\n        for i in range(0, len(fpn_feature_list)):\n            fusion_list.append(nn.functional.interpolate(\n                fpn_feature_list[i],\n                output_size,\n                mode='bilinear', align_corners=True))\n\n        x = fusion_list[0]\n        for i in range(1, len(fusion_list)):\n            x += fusion_list[i]\n\n        # add positional encodings, if one is configured\n        if self.positional_encoding is not None:\n            ignore_mask = x.new_zeros(\n                (x.shape[0], x.shape[-2], x.shape[-1]), dtype=torch.bool)\n            positional_encoding = self.positional_encoding(ignore_mask)\n            x = x + positional_encoding\n\n        return self.dcn(x)\n\n\nclass AlignedModule(nn.Module):\n\n    def __init__(self, inplane, outplane, kernel_size=3):\n        super(AlignedModule, self).__init__()\n        self.down_h = nn.Conv2d(inplane, outplane, 1, bias=False)\n        self.down_l = nn.Conv2d(inplane, outplane, 1, bias=False)\n        self.flow_make = nn.Conv2d(outplane * 2, 2, kernel_size=kernel_size, padding=1, bias=False)\n\n    def forward(self, x):\n        low_feature, h_feature = x\n        
h_feature_orign = h_feature\n        h, w = low_feature.size()[2:]\n        size = (h, w)\n        low_feature = self.down_l(low_feature)\n        h_feature = self.down_h(h_feature)\n        h_feature = F.interpolate(h_feature, size=size, mode=\"bilinear\", align_corners=True)\n        flow = self.flow_make(torch.cat([h_feature, low_feature], 1))\n        h_feature = self.flow_warp(h_feature_orign, flow, size=size)\n\n        return h_feature\n\n    def flow_warp(self, input, flow, size):\n        out_h, out_w = size\n        n, c, h, w = input.size()\n\n        norm = torch.tensor([[[[out_w, out_h]]]]).type_as(input).to(input.device)\n        h = torch.linspace(-1.0, 1.0, out_h).view(-1, 1).repeat(1, out_w)\n        w = torch.linspace(-1.0, 1.0, out_w).repeat(out_h, 1)\n        grid = torch.cat((w.unsqueeze(2), h.unsqueeze(2)), 2)\n        grid = grid.repeat(n, 1, 1, 1).type_as(input).to(input.device)\n        grid = grid + flow.permute(0, 2, 3, 1) / norm\n\n        output = F.grid_sample(input, grid, align_corners=True)\n        return output\n\n\nclass AlignedModulev2PoolingAtten(nn.Module):\n\n    def __init__(self, inplane, outplane, kernel_size=3):\n        super(AlignedModulev2PoolingAtten, self).__init__()\n        self.down_h = nn.Conv2d(inplane, outplane, 1, bias=False)\n        self.down_l = nn.Conv2d(inplane, outplane, 1, bias=False)\n        self.flow_make = nn.Conv2d(outplane*2, 4, kernel_size=kernel_size, padding=1, bias=False)\n        self.flow_gate = nn.Sequential(\n            nn.Conv2d(4, 1, kernel_size=kernel_size, padding=1, bias=False),\n            nn.Sigmoid()\n        )\n\n    def forward(self, x):\n        low_feature, h_feature = x\n        h_feature_orign = h_feature\n        h, w = low_feature.size()[2:]\n        size = (h, w)\n        l_feature = self.down_l(low_feature)\n        h_feature = self.down_h(h_feature)\n        # F.upsample is deprecated; F.interpolate is the equivalent call\n        h_feature = F.interpolate(h_feature, size=size, mode=\"bilinear\", align_corners=True)\n\n        flow = 
self.flow_make(torch.cat([h_feature, l_feature], 1))\n        flow_up, flow_down = flow[:, :2, :, :], flow[:, 2:, :, :]\n\n        h_feature_warp = self.flow_warp(h_feature_orign, flow_up, size=size)\n        l_feature_warp = self.flow_warp(low_feature, flow_down, size=size)\n\n        h_feature_mean = torch.mean(h_feature, dim=1).unsqueeze(1)\n        l_feature_mean = torch.mean(low_feature, dim=1).unsqueeze(1)\n        h_feature_max = torch.max(h_feature, dim=1)[0].unsqueeze(1)\n        l_feature_max = torch.max(low_feature, dim=1)[0].unsqueeze(1)\n\n        flow_gates = self.flow_gate(torch.cat([h_feature_mean, l_feature_mean, h_feature_max, l_feature_max], 1))\n\n        fuse_feature = h_feature_warp * flow_gates + l_feature_warp * (1 - flow_gates)\n\n        return fuse_feature\n\n    def flow_warp(self, input, flow, size):\n        out_h, out_w = size\n        n, c, h, w = input.size()\n        # n, c, h, w\n        # n, 2, h, w\n\n        norm = torch.tensor([[[[out_w, out_h]]]]).type_as(input).to(input.device)\n        h = torch.linspace(-1.0, 1.0, out_h).view(-1, 1).repeat(1, out_w)\n        w = torch.linspace(-1.0, 1.0, out_w).repeat(out_h, 1)\n        grid = torch.cat((w.unsqueeze(2), h.unsqueeze(2)), 2)\n        grid = grid.repeat(n, 1, 1, 1).type_as(input).to(input.device)\n        grid = grid + flow.permute(0, 2, 3, 1) / norm\n\n        output = F.grid_sample(input, grid, align_corners=True)\n        return output\n\n\n@BACKBONES.register_module()\nclass STDCNet1446(nn.Module):\n    def __init__(self, base=64, layers=[4, 5, 3], block_num=4, type=\"cat\", num_classes=1000, dropout=0.20,\n                 pretrain_model='./pretrained_models/STDCNet1446_76.47.tar',\n                 use_conv_last=False, norm_layer=nn.SyncBatchNorm, ):\n        super(STDCNet1446, self).__init__()\n        if type == \"cat\":\n            block = CatBottleneck\n        elif type == \"add\":\n            block = AddBottleneck\n        self.use_conv_last = use_conv_last\n    
    self.features = self._make_layers(base, layers, block_num, block, norm_layer)\n        self.conv_last = ConvX(base * 16, max(1024, base * 16), 1, 1)\n        self.gap = nn.AdaptiveAvgPool2d(1)\n        self.fc = nn.Linear(max(1024, base * 16), max(1024, base * 16), bias=False)\n        self.bn = nn.BatchNorm1d(max(1024, base * 16))\n        self.relu = nn.ReLU(inplace=True)\n        self.dropout = nn.Dropout(p=dropout)\n        self.linear = nn.Linear(max(1024, base * 16), num_classes, bias=False)\n\n        self.x2 = nn.Sequential(self.features[:1])\n        self.x4 = nn.Sequential(self.features[1:2])\n        self.x8 = nn.Sequential(self.features[2:6])\n        self.x16 = nn.Sequential(self.features[6:11])\n        self.x32 = nn.Sequential(self.features[11:])\n\n        if pretrain_model:\n            print('use pretrain model {}'.format(pretrain_model))\n            self.init_weight(pretrain_model)\n        else:\n            self.init_params()\n\n        self.features = None\n        self.conv_last = None\n        self.gap = None\n        self.fc = None\n        self.bn = None\n        self.relu = None\n        self.dropout = None\n        self.linear = None\n\n    def init_weight(self, pretrain_model):\n\n        state_dict = torch.load(pretrain_model, map_location='cpu')[\"state_dict\"]\n        self_state_dict = self.state_dict()\n        for k, v in state_dict.items():\n            self_state_dict.update({k: v})\n        self.load_state_dict(self_state_dict)\n\n    def init_params(self):\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                init.kaiming_normal_(m.weight, mode='fan_out')\n                if m.bias is not None:\n                    init.constant_(m.bias, 0)\n            elif isinstance(m, nn.BatchNorm2d):\n                init.constant_(m.weight, 1)\n                init.constant_(m.bias, 0)\n            elif isinstance(m, nn.Linear):\n                init.normal_(m.weight, std=0.001)\n               
 if m.bias is not None:\n                    init.constant_(m.bias, 0)\n\n    def _make_layers(self, base, layers, block_num, block, norm_layer):\n        features = []\n        features += [ConvX(3, base // 2, 3, 2)]\n        features += [ConvX(base // 2, base, 3, 2)]\n\n        for i, layer in enumerate(layers):\n            for j in range(layer):\n                if i == 0 and j == 0:\n                    features.append(block(base, base * 4, block_num, 2, norm_layer=norm_layer))\n                elif j == 0:\n                    features.append(block(base * int(math.pow(2, i + 1)), base * int(math.pow(2, i + 2)), block_num, 2,\n                                          norm_layer=norm_layer))\n                else:\n                    features.append(block(base * int(math.pow(2, i + 2)), base * int(math.pow(2, i + 2)), block_num, 1,\n                                          norm_layer=norm_layer))\n\n        return nn.Sequential(*features)\n\n    def forward(self, x):\n        feat2 = self.x2(x)\n        feat4 = self.x4(feat2)\n        feat8 = self.x8(feat4)\n        feat16 = self.x16(feat8)\n        feat32 = self.x32(feat16)\n        if self.use_conv_last:\n            feat32 = self.conv_last(feat32)\n\n        return feat4, feat8, feat16, feat32\n\n\n@BACKBONES.register_module()\nclass STDCNet813(nn.Module):\n    def __init__(self, base=64, layers=[2, 2, 2], block_num=4, type=\"cat\", num_classes=1000, dropout=0.20,\n                 pretrain_model='./pretrained_models/STDCNet813_73.91.tar',\n                 use_conv_last=False, norm_layer=nn.BatchNorm2d):\n        super(STDCNet813, self).__init__()\n        if type == \"cat\":\n            block = CatBottleneck\n        elif type == \"add\":\n            block = AddBottleneck\n        self.use_conv_last = use_conv_last\n        self.features = self._make_layers(base, layers, block_num, block, norm_layer)\n        self.conv_last = ConvX(base * 16, max(1024, base * 16), 1, 1)\n        self.gap = 
nn.AdaptiveAvgPool2d(1)\n        self.fc = nn.Linear(max(1024, base * 16), max(1024, base * 16), bias=False)\n        self.bn = nn.BatchNorm1d(max(1024, base * 16))\n        self.relu = nn.ReLU(inplace=True)\n        self.dropout = nn.Dropout(p=dropout)\n        self.linear = nn.Linear(max(1024, base * 16), num_classes, bias=False)\n\n        self.x2 = nn.Sequential(self.features[:1])\n        self.x4 = nn.Sequential(self.features[1:2])\n        self.x8 = nn.Sequential(self.features[2:4])\n        self.x16 = nn.Sequential(self.features[4:6])\n        self.x32 = nn.Sequential(self.features[6:])\n\n        if pretrain_model:\n            print('use pretrain model {}'.format(pretrain_model))\n            self.init_weight(pretrain_model)\n        else:\n            self.init_params()\n\n        self.features = None\n        self.conv_last = None\n        self.gap = None\n        self.fc = None\n        self.bn = None\n        self.relu = None\n        self.dropout = None\n        self.linear = None\n\n    def init_weight(self, pretrain_model):\n\n        state_dict = torch.load(pretrain_model, map_location='cpu')[\"state_dict\"]\n        self_state_dict = self.state_dict()\n        for k, v in state_dict.items():\n            self_state_dict.update({k: v})\n        self.load_state_dict(self_state_dict)\n\n    def init_params(self):\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                init.kaiming_normal_(m.weight, mode='fan_out')\n                if m.bias is not None:\n                    init.constant_(m.bias, 0)\n            elif isinstance(m, nn.BatchNorm2d):\n                init.constant_(m.weight, 1)\n                init.constant_(m.bias, 0)\n            elif isinstance(m, nn.Linear):\n                init.normal_(m.weight, std=0.001)\n                if m.bias is not None:\n                    init.constant_(m.bias, 0)\n\n    def _make_layers(self, base, layers, block_num, block, norm_layer):\n        features = []\n     
   features += [ConvX(3, base // 2, 3, 2)]\n        features += [ConvX(base // 2, base, 3, 2)]\n\n        for i, layer in enumerate(layers):\n            for j in range(layer):\n                if i == 0 and j == 0:\n                    features.append(block(base, base * 4, block_num, 2, norm_layer=norm_layer))\n                elif j == 0:\n                    features.append(block(base * int(math.pow(2, i + 1)), base * int(math.pow(2, i + 2)), block_num, 2,\n                                          norm_layer=norm_layer))\n                else:\n                    features.append(block(base * int(math.pow(2, i + 2)), base * int(math.pow(2, i + 2)), block_num, 1,\n                                          norm_layer=norm_layer))\n\n        return nn.Sequential(*features)\n\n    def forward(self, x):\n        feat2 = self.x2(x)\n        feat4 = self.x4(feat2)\n        feat8 = self.x8(feat4)\n        feat16 = self.x16(feat8)\n        feat32 = self.x32(feat16)\n        if self.use_conv_last:\n            feat32 = self.conv_last(feat32)\n\n        return feat4, feat8, feat16, feat32\n\n\nclass AddBottleneck(nn.Module):\n    def __init__(self, in_planes, out_planes, block_num=3, stride=1, norm_layer=nn.BatchNorm2d):\n        super(AddBottleneck, self).__init__()\n        assert block_num > 1, \"block number should be larger than 1.\"\n        self.conv_list = nn.ModuleList()\n        self.stride = stride\n        if stride == 2:\n            self.avd_layer = nn.Sequential(\n                nn.Conv2d(out_planes // 2, out_planes // 2, kernel_size=3, stride=2, padding=1, groups=out_planes // 2,\n                          bias=False),\n                norm_layer(out_planes // 2),\n            )\n            self.skip = nn.Sequential(\n                nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=2, padding=1, groups=in_planes, bias=False),\n                norm_layer(in_planes),\n                nn.Conv2d(in_planes, out_planes, kernel_size=1, 
bias=False),\n                norm_layer(out_planes),\n            )\n            stride = 1\n\n        for idx in range(block_num):\n            if idx == 0:\n                self.conv_list.append(ConvX(in_planes, out_planes // 2, kernel=1))\n            elif idx == 1 and block_num == 2:\n                self.conv_list.append(ConvX(out_planes // 2, out_planes // 2, stride=stride))\n            elif idx == 1 and block_num > 2:\n                self.conv_list.append(ConvX(out_planes // 2, out_planes // 4, stride=stride))\n            elif idx < block_num - 1:\n                self.conv_list.append(\n                    ConvX(out_planes // int(math.pow(2, idx)), out_planes // int(math.pow(2, idx + 1))))\n            else:\n                self.conv_list.append(ConvX(out_planes // int(math.pow(2, idx)), out_planes // int(math.pow(2, idx))))\n\n    def forward(self, x):\n        out_list = []\n        out = x\n\n        for idx, conv in enumerate(self.conv_list):\n            if idx == 0 and self.stride == 2:\n                out = self.avd_layer(conv(out))\n            else:\n                out = conv(out)\n            out_list.append(out)\n\n        if self.stride == 2:\n            x = self.skip(x)\n\n        return torch.cat(out_list, dim=1) + x\n\n\nclass CatBottleneck(nn.Module):\n    def __init__(self, in_planes, out_planes, block_num=3, stride=1, norm_layer=nn.BatchNorm2d):\n        super(CatBottleneck, self).__init__()\n        assert block_num > 1, \"block number should be larger than 1.\"\n        self.conv_list = nn.ModuleList()\n        self.stride = stride\n        if stride == 2:\n            self.avd_layer = nn.Sequential(\n                nn.Conv2d(out_planes // 2, out_planes // 2, kernel_size=3, stride=2, padding=1, groups=out_planes // 2,\n                          bias=False),\n                norm_layer(out_planes // 2),\n            )\n            self.skip = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)\n            stride = 1\n\n      
  for idx in range(block_num):\n            if idx == 0:\n                self.conv_list.append(ConvX(in_planes, out_planes // 2, kernel=1))\n            elif idx == 1 and block_num == 2:\n                self.conv_list.append(ConvX(out_planes // 2, out_planes // 2, stride=stride))\n            elif idx == 1 and block_num > 2:\n                self.conv_list.append(ConvX(out_planes // 2, out_planes // 4, stride=stride))\n            elif idx < block_num - 1:\n                self.conv_list.append(\n                    ConvX(out_planes // int(math.pow(2, idx)), out_planes // int(math.pow(2, idx + 1))))\n            else:\n                self.conv_list.append(ConvX(out_planes // int(math.pow(2, idx)), out_planes // int(math.pow(2, idx))))\n\n    def forward(self, x):\n        out_list = []\n        out1 = self.conv_list[0](x)\n\n        for idx, conv in enumerate(self.conv_list[1:]):\n            if idx == 0:\n                if self.stride == 2:\n                    out = conv(self.avd_layer(out1))\n                else:\n                    out = conv(out1)\n            else:\n                out = conv(out)\n            out_list.append(out)\n\n        if self.stride == 2:\n            out1 = self.skip(out1)\n        out_list.insert(0, out1)\n\n        out = torch.cat(out_list, dim=1)\n        return out\n\n\nclass ConvX(nn.Module):\n    def __init__(self, in_planes, out_planes, kernel=3, stride=1, norm_layer=nn.BatchNorm2d):\n        super(ConvX, self).__init__()\n        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel, stride=stride, padding=kernel//2, bias=False)\n        self.bn = norm_layer(out_planes)\n        self.relu = nn.ReLU(inplace=True)\n\n    def forward(self, x):\n        out = self.relu(self.bn(self.conv(x)))\n        return out"
  },
  {
    "path": "knet/det/utils.py",
    "content": "from typing import List\n\nimport torch\nimport torch.nn.functional as F\nfrom mmdet.utils import get_root_logger\n\n\ndef sem2ins_masks(gt_sem_seg,\n                  ignore_label=255,\n                  label_shift=80,\n                  thing_label_in_seg=0):\n    classes = torch.unique(gt_sem_seg)\n    ins_masks = []\n    ins_labels = []\n    for i in classes:\n        # skip ignore class 255 and \"special thing class\" in semantic seg\n        if i == ignore_label or i == thing_label_in_seg:\n            continue\n        ins_labels.append(i)\n        ins_masks.append(gt_sem_seg == i)\n    # 0 is the special thing class in semantic seg, so we also shift it by 1\n    # Thus, 0-79 is foreground classes of things (similar in instance seg)\n    # 80-151 is foreground classes of stuffs (shifted by the original index)\n    if len(ins_labels) > 0:\n        ins_labels = torch.stack(ins_labels) + label_shift - 1\n        ins_masks = torch.cat(ins_masks)\n    else:\n        ins_labels = gt_sem_seg.new_zeros(size=[0])\n        ins_masks = gt_sem_seg.new_zeros(\n            size=[0, gt_sem_seg.shape[-2], gt_sem_seg.shape[-1]])\n    return ins_labels.long(), ins_masks.float()\n\n\ndef sem2ins_masks_cityscapes(gt_sem_seg,\n                             ignore_label=255,\n                             label_shift=8,\n                             thing_label_in_seg=list(range(11, 19))):\n    \"\"\"\n        Shift the cityscapes semantic labels to instance labels and masks.\n    \"\"\"\n    # assert label range from 0-18 (255)\n    classes = torch.unique(gt_sem_seg)\n    ins_masks = []\n    ins_labels = []\n    for i in classes:\n        # skip ignore class 255 and \"special thing class\" in semantic seg\n        if i == ignore_label or i in thing_label_in_seg:\n            continue\n        ins_labels.append(i)\n        ins_masks.append(gt_sem_seg == i)\n    # For cityscapes, 0-7 is foreground classes of things (similar in instance seg)\n    # 8-18 is foreground 
classes of stuffs (shifted by the original index)\n    if len(ins_labels) > 0:\n        ins_labels = torch.stack(ins_labels) + label_shift\n        ins_masks = torch.cat(ins_masks)\n    else:\n        ins_labels = gt_sem_seg.new_zeros(size=[0])\n        ins_masks = gt_sem_seg.new_zeros(\n            size=[0, gt_sem_seg.shape[-2], gt_sem_seg.shape[-1]])\n    return ins_labels.long(), ins_masks.float()\n\n\ndef sem2ins_masks_kitti_step(gt_sem_seg,\n                             ignore_label=255,\n                             label_shift=2,\n                             thing_label_in_seg=(11, 13)):\n    \"\"\"\n        Shift the KITTI-STEP semantic labels to instance labels and masks.\n    \"\"\"\n    # assert label range from 0-18 (255)\n    classes = torch.unique(gt_sem_seg)\n    ins_masks = []\n    ins_labels = []\n    for i in classes:\n        # skip the ignore label (255) and the thing classes in the semantic map\n        if i == ignore_label or i in thing_label_in_seg:\n            continue\n        # close the gaps left by the skipped thing labels so the stuff labels stay contiguous\n        offset = 0\n        for thing_label in thing_label_in_seg:\n            if i > thing_label:\n                offset -= 1\n        ins_labels.append(i + offset)\n        ins_masks.append(gt_sem_seg == i)\n    # For KITTI-STEP, 0-1 are the two thing classes (person and car)\n    # 2-18 are the 17 stuff classes (made contiguous, then shifted by label_shift)\n    if len(ins_labels) > 0:\n        ins_labels = torch.stack(ins_labels) + label_shift\n        ins_masks = torch.cat(ins_masks)\n    else:\n        ins_labels = gt_sem_seg.new_zeros(size=[0])\n        ins_masks = gt_sem_seg.new_zeros(\n            size=[0, gt_sem_seg.shape[-2], gt_sem_seg.shape[-1]])\n    return ins_labels.long(), ins_masks.float()"
  },
  {
    "path": "knet/kernel_updator.py",
    "content": "import torch.nn as nn\nimport torch.nn.functional as F\nfrom mmcv.cnn import build_activation_layer, build_norm_layer\nfrom mmcv.cnn.bricks.transformer import TRANSFORMER_LAYER\n\n\n@TRANSFORMER_LAYER.register_module()\nclass KernelUpdator(nn.Module):\n\n    def __init__(self,\n                 in_channels=256,\n                 feat_channels=64,\n                 out_channels=None,\n                 input_feat_shape=3,\n                 gate_sigmoid=True,\n                 gate_norm_act=False,\n                 activate_out=False,\n                 act_cfg=dict(type='ReLU', inplace=True),\n                 norm_cfg=dict(type='LN')):\n        super(KernelUpdator, self).__init__()\n        self.in_channels = in_channels\n        self.feat_channels = feat_channels\n        self.out_channels_raw = out_channels\n        self.gate_sigmoid = gate_sigmoid\n        self.gate_norm_act = gate_norm_act\n        self.activate_out = activate_out\n        if isinstance(input_feat_shape, int):\n            input_feat_shape = [input_feat_shape] * 2\n        self.input_feat_shape = input_feat_shape\n        self.act_cfg = act_cfg\n        self.norm_cfg = norm_cfg\n        self.out_channels = out_channels if out_channels else in_channels\n\n        self.num_params_in = self.feat_channels\n        self.num_params_out = self.feat_channels\n        self.dynamic_layer = nn.Linear(\n            self.in_channels, self.num_params_in + self.num_params_out)\n        self.input_layer = nn.Linear(self.in_channels,\n                                     self.num_params_in + self.num_params_out,\n                                     1)\n        self.input_gate = nn.Linear(self.in_channels, self.feat_channels, 1)\n        self.update_gate = nn.Linear(self.in_channels, self.feat_channels, 1)\n        if self.gate_norm_act:\n            self.gate_norm = build_norm_layer(norm_cfg, self.feat_channels)[1]\n
\n        self.norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1]\n        self.norm_out = build_norm_layer(norm_cfg, self.feat_channels)[1]\n        self.input_norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1]\n        self.input_norm_out = build_norm_layer(norm_cfg, self.feat_channels)[1]\n\n        self.activation = build_activation_layer(act_cfg)\n\n        self.fc_layer = nn.Linear(self.feat_channels, self.out_channels, 1)\n        self.fc_norm = build_norm_layer(norm_cfg, self.out_channels)[1]\n\n    def forward(self, update_feature, input_feature):\n        update_feature = update_feature.reshape(-1, self.in_channels)\n        num_proposals = update_feature.size(0)\n        parameters = self.dynamic_layer(update_feature)\n        param_in = parameters[:, :self.num_params_in].view(\n            -1, self.feat_channels)\n        param_out = parameters[:, -self.num_params_out:].view(\n            -1, self.feat_channels)\n\n        input_feats = self.input_layer(\n            input_feature.reshape(num_proposals, -1, self.feat_channels))\n        input_in = input_feats[..., :self.num_params_in]\n        input_out = input_feats[..., -self.num_params_out:]\n\n        gate_feats = input_in * param_in.unsqueeze(-2)\n        if self.gate_norm_act:\n            gate_feats = self.activation(self.gate_norm(gate_feats))\n\n        input_gate = self.input_norm_in(self.input_gate(gate_feats))\n        update_gate = self.norm_in(self.update_gate(gate_feats))\n        if self.gate_sigmoid:\n            input_gate = input_gate.sigmoid()\n            update_gate = update_gate.sigmoid()\n        param_out = self.norm_out(param_out)\n        input_out = self.input_norm_out(input_out)\n\n        if self.activate_out:\n            param_out = self.activation(param_out)\n            input_out = self.activation(input_out)\n
\n        # param_out: (num_proposals, feat_channels) -> unsqueeze(-2) -> (num_proposals, 1, feat_channels),\n        # which broadcasts against the gates and input_out of shape (num_proposals, -1, feat_channels)\n        features = update_gate * param_out.unsqueeze(\n            -2) + input_gate * input_out\n\n        features = self.fc_layer(features)\n        features = self.fc_norm(features)\n        features = self.activation(features)\n\n        return features"
  },
  {
    "path": "knet/video/__init__.py",
    "content": ""
  },
  {
    "path": "knet/video/dice_loss.py",
    "content": "import torch\nimport torch.nn as nn\nfrom mmdet.models.builder import LOSSES, build_loss\nfrom mmdet.models.losses.utils import weighted_loss\n\n\n@weighted_loss\ndef dice_loss(input, target, eps=1e-3, numerator_eps=0):\n    \"\"\"Per-sample Dice loss between soft predictions and binary targets.\n\n    For each sample: 1 - (2 * sum(p * t) + numerator_eps) / (sum(p * p) + sum(t * t) + 2 * eps).\n    ``input`` is used as-is (no sigmoid is applied here), so pass probabilities.\n    Element weights and reduction are added by the ``weighted_loss`` decorator.\n    \"\"\"\n    input = input.reshape(input.size()[0], -1)\n    target = target.reshape(target.size()[0], -1).float()\n\n    a = torch.sum(input * target, 1)\n    b = torch.sum(input * input, 1) + eps\n    c = torch.sum(target * target, 1) + eps\n    d = (2 * a + numerator_eps) / (b + c)\n    return 1 - d\n\n#\n# @LOSSES.register_module()\n# class DiceLoss(nn.Module):\n#\n#     def __init__(self,\n#                  eps=1e-3,\n#                  numerator_eps=0.0,\n#                  use_sigmoid=True,\n#                  reduction='mean',\n#                  loss_weight=1.0):\n#         super(DiceLoss, self).__init__()\n#         self.eps = eps\n#         self.reduction = reduction\n#         self.loss_weight = loss_weight\n#         self.use_sigmoid = use_sigmoid\n#         self.numerator_eps = numerator_eps\n#\n#     def forward(self,\n#                 pred,\n#                 target,\n#                 weight=None,\n#                 avg_factor=None,\n#                 reduction_override=None,\n#                 **kwargs):\n#         if weight is not None and not torch.any(weight > 0):\n#             return (pred * weight).sum()  # 0\n#         assert reduction_override in (None, 'none', 'mean', 'sum')\n#         reduction = (\n#             reduction_override if reduction_override else self.reduction)\n#         pred = pred.sigmoid()\n#         loss = self.loss_weight * dice_loss(\n#             pred,\n#             target,\n#             weight,\n#             eps=self.eps,\n#             numerator_eps=self.numerator_eps,\n#             reduction=reduction,\n#             avg_factor=avg_factor,\n#             **kwargs)\n#         return loss\n"
  },
  {
    "path": "knet/video/kernel_head.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom mmcv.cnn import (ConvModule, bias_init_with_prob, normal_init)\nfrom mmdet.core import build_assigner, build_sampler, multi_apply, reduce_mean\nfrom mmdet.models.builder import HEADS, build_loss, build_neck\nfrom mmdet.models.losses import accuracy\nfrom mmdet.utils import get_root_logger\n\n\n@HEADS.register_module()\nclass VideoConvKernelHead(nn.Module):\n    \"\"\"\n        This head for init mask and kernel prediction\n    \"\"\"\n    def __init__(self,\n                 num_proposals=100,\n                 in_channels=256,\n                 out_channels=256,\n                 num_heads=8,\n                 num_cls_fcs=1,\n                 num_seg_convs=1,\n                 num_loc_convs=1,\n                 att_dropout=False,\n                 localization_fpn=None,\n                 conv_kernel_size=1,\n                 norm_cfg=dict(type='GN', num_groups=32),\n                 semantic_fpn=True,\n                 train_cfg=None,\n                 num_classes=80,\n                 xavier_init_kernel=False,\n                 kernel_init_std=0.01,\n                 use_binary=False,\n                 proposal_feats_with_obj=False,\n                 loss_mask=None,\n                 loss_seg=None,\n                 loss_cls=None,\n                 loss_dice=None,\n                 loss_rank=None,\n                 feat_downsample_stride=1,\n                 feat_refine_stride=1,\n                 feat_refine=True,\n                 with_embed=False,\n                 feat_embed_only=False,\n                 conv_normal_init=False,\n                 mask_out_stride=4,\n                 hard_target=False,\n                 num_thing_classes=80,\n                 num_stuff_classes=53,\n                 mask_assign_stride=4,\n                 ignore_label=255,\n                 thing_label_in_seg=0,\n                 cat_stuff_mask=False,\n                 
link_previous=False,\n                 **kwargs):\n        super(VideoConvKernelHead, self).__init__()\n        self.num_proposals = num_proposals\n        self.num_cls_fcs = num_cls_fcs\n        self.train_cfg = train_cfg\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.num_classes = num_classes\n        self.proposal_feats_with_obj = proposal_feats_with_obj\n        self.sampling = False\n        self.localization_fpn = build_neck(localization_fpn)\n        self.semantic_fpn = semantic_fpn\n        self.norm_cfg = norm_cfg\n        self.num_heads = num_heads\n        self.att_dropout = att_dropout\n        self.mask_out_stride = mask_out_stride\n        self.hard_target = hard_target\n        self.conv_kernel_size = conv_kernel_size\n        self.xavier_init_kernel = xavier_init_kernel\n        self.kernel_init_std = kernel_init_std\n        self.feat_downsample_stride = feat_downsample_stride\n        self.feat_refine_stride = feat_refine_stride\n        self.conv_normal_init = conv_normal_init\n        self.feat_refine = feat_refine\n        self.with_embed = with_embed\n        self.feat_embed_only = feat_embed_only\n        self.num_loc_convs = num_loc_convs\n        self.num_seg_convs = num_seg_convs\n        self.use_binary = use_binary\n        self.num_thing_classes = num_thing_classes\n        self.num_stuff_classes = num_stuff_classes\n        self.mask_assign_stride = mask_assign_stride\n        self.ignore_label = ignore_label\n        self.thing_label_in_seg = thing_label_in_seg\n        self.cat_stuff_mask = cat_stuff_mask\n        self.link_previous = link_previous\n\n        if loss_mask is not None:\n            self.loss_mask = build_loss(loss_mask)\n        else:\n            self.loss_mask = loss_mask\n\n        if loss_dice is not None:\n            self.loss_dice = build_loss(loss_dice)\n        else:\n            self.loss_dice = loss_dice\n\n        if loss_seg is not None:\n            
self.loss_seg = build_loss(loss_seg)\n        else:\n            self.loss_seg = loss_seg\n        if loss_cls is not None:\n            self.loss_cls = build_loss(loss_cls)\n        else:\n            self.loss_cls = loss_cls\n\n        if loss_rank is not None:\n            self.loss_rank = build_loss(loss_rank)\n        else:\n            self.loss_rank = loss_rank\n\n        if self.train_cfg:\n            self.assigner = build_assigner(self.train_cfg.assigner)\n            # use PseudoSampler when sampling is False\n            if self.sampling and hasattr(self.train_cfg, 'sampler'):\n                sampler_cfg = self.train_cfg.sampler\n            else:\n                sampler_cfg = dict(type='MaskPseudoSampler')\n            self.sampler = build_sampler(sampler_cfg, context=self)\n        self._init_layers()\n\n    def _init_layers(self):\n        \"\"\"Initialize a sparse set of proposal boxes and proposal features.\"\"\"\n        self.init_kernels = nn.Conv2d(\n            self.out_channels,\n            self.num_proposals,\n            self.conv_kernel_size,\n            padding=int(self.conv_kernel_size // 2),\n            bias=False)  # (N, C, 1, 1) -> (N, C)\n\n        if self.semantic_fpn:\n            self.conv_seg = nn.Conv2d(self.out_channels, self.num_classes, 1)\n\n        if self.feat_downsample_stride > 1 and self.feat_refine:\n            self.ins_downsample = ConvModule(\n                self.in_channels,\n                self.out_channels,\n                3,\n                stride=self.feat_refine_stride,  # 2\n                padding=1,\n                norm_cfg=self.norm_cfg)\n            self.seg_downsample = ConvModule(\n                self.in_channels,\n                self.out_channels,\n                3,\n                stride=self.feat_refine_stride,  # 2\n                padding=1,\n                norm_cfg=self.norm_cfg)\n\n        self.loc_convs = nn.ModuleList()\n        for i in range(self.num_loc_convs):\n            
self.loc_convs.append(\n                ConvModule(\n                    self.in_channels,\n                    self.out_channels,\n                    1,\n                    norm_cfg=self.norm_cfg))\n\n        self.seg_convs = nn.ModuleList()\n        for i in range(self.num_seg_convs):\n            self.seg_convs.append(\n                ConvModule(\n                    self.in_channels,\n                    self.out_channels,\n                    1,\n                    norm_cfg=self.norm_cfg))\n\n    def init_weights(self):\n        self.localization_fpn.init_weights()\n\n        if self.feat_downsample_stride > 1 and self.conv_normal_init:\n            logger = get_root_logger()\n            logger.info('Initialize convs in KPN head by normal std 0.01')\n            for conv in [self.loc_convs, self.seg_convs]:\n                for m in conv.modules():\n                    if isinstance(m, nn.Conv2d):\n                        normal_init(m, std=0.01)\n\n        if self.semantic_fpn:\n            bias_seg = bias_init_with_prob(0.01)\n            if self.loss_seg.use_sigmoid:\n                normal_init(self.conv_seg, std=0.01, bias=bias_seg)\n            else:\n                normal_init(self.conv_seg, mean=0, std=0.01)\n        if self.xavier_init_kernel:\n            logger = get_root_logger()\n            logger.info('Initialize kernels by xavier uniform')\n            nn.init.xavier_uniform_(self.init_kernels.weight)\n        else:\n            logger = get_root_logger()\n            logger.info(\n                f'Initialize kernels by normal std: {self.kernel_init_std}')\n            normal_init(self.init_kernels, mean=0, std=self.kernel_init_std)\n\n    def _decode_init_proposals(self, img, img_metas,\n                               previous_obj_feats=None, previous_mask_preds=None, previous_x_feats=None):\n        num_imgs = len(img_metas)\n\n        localization_feats = self.localization_fpn(img)\n\n        ## thing branch\n        if 
isinstance(localization_feats, list):\n            loc_feats = localization_feats[0]\n        else:\n            loc_feats = localization_feats\n        for conv in self.loc_convs:\n            loc_feats = conv(loc_feats)\n        if self.feat_downsample_stride > 1 and self.feat_refine:\n            loc_feats = self.ins_downsample(loc_feats)\n\n        # init kernel prediction\n        mask_preds = self.init_kernels(loc_feats)  # init mask prediction\n\n        # stuff branch\n        if self.semantic_fpn:\n            if isinstance(localization_feats, list):\n                semantic_feats = localization_feats[1]\n            else:\n                semantic_feats = localization_feats\n            for conv in self.seg_convs:\n                semantic_feats = conv(semantic_feats)\n            if self.feat_downsample_stride > 1 and self.feat_refine:\n                semantic_feats = self.seg_downsample(semantic_feats)\n        else:\n            semantic_feats = None\n\n        if semantic_feats is not None:\n            seg_preds = self.conv_seg(semantic_feats)\n        else:\n            seg_preds = None\n\n        # init things\n        proposal_feats = self.init_kernels.weight.clone()\n        proposal_feats = proposal_feats[None].expand(num_imgs,\n                                                     *proposal_feats.size())\n\n        if semantic_feats is not None:\n            x_feats = semantic_feats + loc_feats\n        else:\n            x_feats = loc_feats\n\n        if self.proposal_feats_with_obj:\n            sigmoid_masks = mask_preds.sigmoid()\n            nonzero_inds = sigmoid_masks > 0.5\n            if self.use_binary:\n                sigmoid_masks = nonzero_inds.float()\n            else:\n                sigmoid_masks = nonzero_inds.float() * sigmoid_masks\n            obj_feats = torch.einsum('bnhw, bchw->bnc', sigmoid_masks, x_feats)\n\n        cls_scores = None\n\n        if self.proposal_feats_with_obj:  # default True\n            
proposal_feats = proposal_feats + obj_feats.view(\n                num_imgs, self.num_proposals, self.out_channels, 1, 1)\n\n        if self.cat_stuff_mask and not self.training:\n            mask_preds = torch.cat(\n                [mask_preds, seg_preds[:, self.num_thing_classes:]], dim=1)\n            stuff_kernels = self.conv_seg.weight[self.\n                                                 num_thing_classes:].clone()\n            stuff_kernels = stuff_kernels[None].expand(num_imgs,\n                                                       *stuff_kernels.size())\n            proposal_feats = torch.cat([proposal_feats, stuff_kernels], dim=1)  # (b, N_{st}+N_{th}, c)\n\n        return proposal_feats, x_feats, mask_preds, cls_scores, seg_preds\n\n    def forward_train(self,\n                      img,\n                      img_metas,\n                      gt_masks,\n                      gt_labels,\n                      gt_sem_seg=None,\n                      gt_sem_cls=None,\n                      previous_obj_feats=None,\n                      previous_mask_preds=None,\n                      previous_x_feats=None):\n        \"\"\"Forward function in training stage.\"\"\"\n        num_imgs = len(img_metas)\n        results = self._decode_init_proposals(img, img_metas, previous_obj_feats, previous_mask_preds, previous_x_feats)\n        (proposal_feats, x_feats, mask_preds, cls_scores, seg_preds) = results\n        if self.feat_downsample_stride > 1:\n            scaled_mask_preds = F.interpolate(\n                mask_preds,\n                scale_factor=self.feat_downsample_stride,\n                mode='bilinear',\n                align_corners=False)\n            if seg_preds is not None:\n                scaled_seg_preds = F.interpolate(\n                    seg_preds,\n                    scale_factor=self.feat_downsample_stride,\n                    mode='bilinear',\n                    align_corners=False)\n        else:\n            scaled_mask_preds = 
mask_preds  # thing\n            scaled_seg_preds = seg_preds   # stuff\n\n        if self.hard_target:\n            gt_masks = [x.bool().float() for x in gt_masks]\n        else:\n            gt_masks = gt_masks\n\n        sampling_results = []\n        if cls_scores is None:\n            detached_cls_scores = [None] * num_imgs\n        else:\n            detached_cls_scores = cls_scores.detach()\n\n        for i in range(num_imgs):\n            assign_result = self.assigner.assign(scaled_mask_preds[i].detach(),\n                                                 detached_cls_scores[i],\n                                                 gt_masks[i], gt_labels[i],\n                                                 img_meta=img_metas[i])\n            sampling_result = self.sampler.sample(assign_result,\n                                                  scaled_mask_preds[i],\n                                                  gt_masks[i])\n            sampling_results.append(sampling_result)\n\n        mask_targets = self.get_targets(\n            sampling_results,\n            gt_masks,\n            self.train_cfg,\n            True,\n            gt_sem_seg=gt_sem_seg,\n            gt_sem_cls=gt_sem_cls)\n\n        losses = self.loss(scaled_mask_preds, cls_scores, scaled_seg_preds,\n                           proposal_feats, *mask_targets)\n\n        if self.cat_stuff_mask and self.training:\n            mask_preds = torch.cat(\n                [mask_preds, seg_preds[:, self.num_thing_classes:]], dim=1)\n            stuff_kernels = self.conv_seg.weight[self.\n                                                 num_thing_classes:].clone()\n            stuff_kernels = stuff_kernels[None].expand(num_imgs,\n                                                       *stuff_kernels.size())\n            proposal_feats = torch.cat([proposal_feats, stuff_kernels], dim=1)\n\n        return losses, proposal_feats, x_feats, mask_preds, cls_scores\n\n    def loss(self,\n             
mask_pred,\n             cls_scores,\n             seg_preds,\n             proposal_feats,\n             labels,\n             label_weights,\n             mask_targets,\n             mask_weights,\n             seg_targets,\n             reduction_override=None,\n             **kwargs):\n        losses = dict()\n        bg_class_ind = self.num_classes\n        # note: as in Sparse R-CNN, num_gt == num_pos\n        pos_inds = (labels >= 0) & (labels < bg_class_ind)\n        num_preds = mask_pred.shape[0] * mask_pred.shape[1]\n\n        if cls_scores is not None:\n            num_pos = pos_inds.sum().float()\n            avg_factor = reduce_mean(num_pos)\n            assert mask_pred.shape[0] == cls_scores.shape[0]\n            assert mask_pred.shape[1] == cls_scores.shape[1]\n            losses['loss_rpn_cls'] = self.loss_cls(\n                cls_scores.view(num_preds, -1),\n                labels,\n                label_weights,\n                avg_factor=avg_factor,\n                reduction_override=reduction_override)\n            losses['rpn_pos_acc'] = accuracy(\n                cls_scores.view(num_preds, -1)[pos_inds], labels[pos_inds])\n\n        bool_pos_inds = pos_inds.type(torch.bool)\n        # 0~self.num_classes-1 are FG, self.num_classes is BG\n        # mask and dice losses are only computed on FG (positive) predictions\n        H, W = mask_pred.shape[-2:]\n        if pos_inds.any():\n            pos_mask_pred = mask_pred.reshape(num_preds, H, W)[bool_pos_inds]\n            pos_mask_targets = mask_targets[bool_pos_inds]\n            losses['loss_rpn_mask'] = self.loss_mask(pos_mask_pred,\n                                                     pos_mask_targets)\n            losses['loss_rpn_dice'] = self.loss_dice(pos_mask_pred,\n                                                     pos_mask_targets)\n
\n            if self.loss_rank is not None:\n                batch_size = mask_pred.size(0)\n                rank_target = mask_targets.new_full((batch_size, H, W),\n                                                    self.ignore_label,\n                                                    dtype=torch.long)\n                rank_inds = pos_inds.view(batch_size,\n                                          -1).nonzero(as_tuple=False)\n                batch_mask_targets = mask_targets.view(batch_size, -1, H,\n                                                       W).bool()\n                for i in range(batch_size):\n                    curr_inds = (rank_inds[:, 0] == i)\n                    curr_rank = rank_inds[:, 1][curr_inds]\n                    for j in curr_rank:\n                        rank_target[i][batch_mask_targets[i][j]] = j\n                losses['loss_rpn_rank'] = self.loss_rank(\n                    mask_pred, rank_target, ignore_index=self.ignore_label)\n\n        else:\n            losses['loss_rpn_mask'] = mask_pred.sum() * 0\n            losses['loss_rpn_dice'] = mask_pred.sum() * 0\n            if self.loss_rank is not None:\n                # use the same key as the positive branch so the logged losses stay consistent\n                losses['loss_rpn_rank'] = mask_pred.sum() * 0\n\n        if seg_preds is not None:\n            if self.loss_seg.use_sigmoid:\n                cls_channel = seg_preds.shape[1]\n                flatten_seg = seg_preds.view(\n                    -1, cls_channel,\n                    H * W).permute(0, 2, 1).reshape(-1, cls_channel)\n                flatten_seg_target = seg_targets.view(-1)\n                num_dense_pos = (flatten_seg_target >= 0) & (\n                    flatten_seg_target < bg_class_ind)\n                num_dense_pos = num_dense_pos.sum().float().clamp(min=1.0)\n                losses['loss_rpn_seg'] = self.loss_seg(\n                    flatten_seg,\n                    flatten_seg_target,\n                    avg_factor=num_dense_pos)\n            else:\n                cls_channel = seg_preds.shape[1]\n                flatten_seg = seg_preds.view(-1, cls_channel, H * W).permute(\n                    0, 2, 1).reshape(-1, cls_channel)\n                flatten_seg_target = seg_targets.view(-1)\n                losses['loss_rpn_seg'] = self.loss_seg(flatten_seg,\n                                                       flatten_seg_target, ignore_index=self.num_classes)\n
\n        return losses\n\n    def _get_target_single(self, pos_inds, neg_inds, pos_mask, neg_mask,\n                           pos_gt_mask, pos_gt_labels, gt_sem_seg, gt_sem_cls,\n                           cfg):\n        num_pos = pos_mask.size(0)\n        num_neg = neg_mask.size(0)\n        num_samples = num_pos + num_neg\n        H, W = pos_mask.shape[-2:]\n        # original implementation uses new_zeros since BG are set to be 0\n        # now use new_full because BG cat_id = num_classes,\n        # FG cat_id = [0, num_classes-1]\n        labels = pos_mask.new_full((num_samples, ),\n                                   self.num_classes,\n                                   dtype=torch.long)\n        label_weights = pos_mask.new_zeros(num_samples)\n        mask_targets = pos_mask.new_zeros(num_samples, H, W)\n        mask_weights = pos_mask.new_zeros(num_samples, H, W)\n        seg_targets = pos_mask.new_full((H, W),\n                                        self.num_classes,\n                                        dtype=torch.long)\n\n        if gt_sem_cls is not None and gt_sem_seg is not None:\n            gt_sem_seg = gt_sem_seg.bool()\n            for sem_mask, sem_cls in zip(gt_sem_seg, gt_sem_cls):\n                seg_targets[sem_mask] = sem_cls.long()\n\n        if num_pos > 0:\n            labels[pos_inds] = pos_gt_labels\n            pos_weight = 1.0 if cfg.pos_weight <= 0 else cfg.pos_weight\n            label_weights[pos_inds] = pos_weight\n            mask_targets[pos_inds, ...] = pos_gt_mask\n            mask_weights[pos_inds, ...] = 1\n            for i in range(num_pos):\n                seg_targets[pos_gt_mask[i].bool()] = pos_gt_labels[i]\n\n        if num_neg > 0:\n            label_weights[neg_inds] = 1.0\n\n        return labels, label_weights, mask_targets, mask_weights, seg_targets\n
\n    def get_targets(self,\n                    sampling_results,\n                    gt_mask,\n                    rpn_train_cfg,\n                    concat=True,\n                    gt_sem_seg=None,\n                    gt_sem_cls=None):\n        pos_inds_list = [res.pos_inds for res in sampling_results]\n        neg_inds_list = [res.neg_inds for res in sampling_results]\n        pos_mask_list = [res.pos_masks for res in sampling_results]\n        neg_mask_list = [res.neg_masks for res in sampling_results]\n        pos_gt_mask_list = [res.pos_gt_masks for res in sampling_results]\n        pos_gt_labels_list = [res.pos_gt_labels for res in sampling_results]\n        if gt_sem_seg is None:\n            # one placeholder per image, so multi_apply does not truncate the batch\n            gt_sem_seg = [None] * len(sampling_results)\n            gt_sem_cls = [None] * len(sampling_results)\n        results = multi_apply(\n            self._get_target_single,\n            pos_inds_list,\n            neg_inds_list,\n            pos_mask_list,\n            neg_mask_list,\n            pos_gt_mask_list,\n            pos_gt_labels_list,\n            gt_sem_seg,\n            gt_sem_cls,\n            cfg=rpn_train_cfg)\n        (labels, label_weights, mask_targets, mask_weights,\n         seg_targets) = results\n        if concat:\n            labels = torch.cat(labels, 0)\n            label_weights = torch.cat(label_weights, 0)\n            mask_targets = torch.cat(mask_targets, 0)\n            mask_weights = torch.cat(mask_weights, 0)\n            seg_targets = torch.stack(seg_targets, 0)\n        return labels, label_weights, mask_targets, mask_weights, seg_targets\n\n    def simple_test_rpn(self, img, img_metas,\n            previous_obj_feats=None, previous_mask_preds=None, previous_x_feats=None):\n        \"\"\"Forward function in testing stage.\"\"\"\n        return self._decode_init_proposals(img, img_metas, previous_obj_feats, previous_mask_preds, previous_x_feats)\n\n    def forward_dummy(self, img, img_metas):\n        \"\"\"Dummy forward function.\n\n        Used in flops calculation.\n        \"\"\"\n        return self._decode_init_proposals(img, img_metas)\n"
  },
  {
    "path": "knet/video/kernel_iter_head.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom mmdet.core import build_assigner, build_sampler\nfrom mmdet.models.builder import HEADS, build_head\nfrom mmdet.models.roi_heads import BaseRoIHead\nfrom knet.det.mask_pseudo_sampler import MaskPseudoSampler\n\n\n@HEADS.register_module()\nclass VideoKernelIterHead(BaseRoIHead):\n\n    def __init__(self,\n                 num_stages=6,\n                 recursive=False,\n                 assign_stages=5,\n                 stage_loss_weights=(1, 1, 1, 1, 1, 1),\n                 proposal_feature_channel=256,\n                 merge_cls_scores=False,\n                 do_panoptic=False,\n                 post_assign=False,\n                 hard_target=False,\n                 merge_joint=False,\n                 num_proposals=100,\n                 num_thing_classes=80,\n                 num_stuff_classes=53,\n                 mask_assign_stride=4,\n                 ignore_label=255,\n                 thing_label_in_seg=0,\n                 with_track=False,\n                 mask_head=dict(\n                     type='KernelUpdateHead',\n                     num_classes=80,\n                     num_fcs=2,\n                     num_heads=8,\n                     num_cls_fcs=1,\n                     num_reg_fcs=3,\n                     feedforward_channels=2048,\n                     hidden_channels=256,\n                     dropout=0.0,\n                     roi_feat_size=7,\n                     ffn_act_cfg=dict(type='ReLU', inplace=True)),\n                 mask_out_stride=4,\n                 train_cfg=None,\n                 test_cfg=None,\n                 **kwargs):\n        assert mask_head is not None\n        assert len(stage_loss_weights) == num_stages\n        self.num_stages = num_stages\n        self.stage_loss_weights = stage_loss_weights\n        self.proposal_feature_channel = proposal_feature_channel\n        self.merge_cls_scores = merge_cls_scores\n        
self.recursive = recursive\n        self.post_assign = post_assign\n        self.mask_out_stride = mask_out_stride\n        self.hard_target = hard_target\n        self.merge_joint = merge_joint\n        self.assign_stages = assign_stages\n        self.do_panoptic = do_panoptic\n        self.num_thing_classes = num_thing_classes\n        self.num_stuff_classes = num_stuff_classes\n        self.mask_assign_stride = mask_assign_stride\n        self.thing_label_in_seg = thing_label_in_seg\n        self.num_proposals = num_proposals\n        self.ignore_label = ignore_label\n        self.with_track = with_track\n        super(VideoKernelIterHead, self).__init__(\n            mask_head=mask_head, train_cfg=train_cfg, test_cfg=test_cfg, **kwargs)\n        # train_cfg would be None when run the test.py\n        if train_cfg is not None:\n            for stage in range(num_stages):\n                assert isinstance(\n                    self.mask_sampler[stage], MaskPseudoSampler), \\\n                    'Sparse Mask only support `MaskPseudoSampler`'\n\n    def init_bbox_head(self, mask_roi_extractor, mask_head):\n        \"\"\"Initialize box head and box roi extractor.\n\n        Args:\n            mask_roi_extractor (dict): Config of box roi extractor.\n            mask_head (dict): Config of box in box head.\n        \"\"\"\n        pass\n\n    def init_assigner_sampler(self):\n        \"\"\"Initialize assigner and sampler for each stage.\"\"\"\n        self.mask_assigner = []\n        self.mask_sampler = []\n        if self.train_cfg is not None:\n            for idx, rcnn_train_cfg in enumerate(self.train_cfg):\n                self.mask_assigner.append(\n                    build_assigner(rcnn_train_cfg.assigner))\n                self.current_stage = idx\n                self.mask_sampler.append(\n                    build_sampler(rcnn_train_cfg.sampler, context=self))\n\n    def init_weights(self):\n        for i in range(self.num_stages):\n            
self.mask_head[i].init_weights()\n\n    def init_mask_head(self, mask_roi_extractor, mask_head):\n        \"\"\"Initialize mask head and mask roi extractor.\n\n        Args:\n            mask_roi_extractor (dict): Config of mask roi extractor.\n            mask_head (dict): Config of mask in mask head.\n        \"\"\"\n        self.mask_head = nn.ModuleList()\n        if not isinstance(mask_head, list):\n            mask_head = [mask_head for _ in range(self.num_stages)]\n        assert len(mask_head) == self.num_stages\n        for head in mask_head:\n            self.mask_head.append(build_head(head))\n        if self.recursive:\n            for i in range(self.num_stages):\n                self.mask_head[i] = self.mask_head[0]\n\n    def _mask_forward(self, stage, x, object_feats, mask_preds, img_metas,\n                      previous_obj_feats=None,\n                      previous_mask_preds=None,\n                      previous_x_feats=None\n                      ):\n        mask_head = self.mask_head[stage]\n        cls_score, mask_preds, object_feats, x_feats, object_feats_track = mask_head(\n            x, object_feats, mask_preds, img_metas=img_metas,\n            previous_obj_feats=previous_obj_feats,\n            previous_mask_preds=previous_mask_preds,\n            previous_x_feats=previous_x_feats\n        )\n        if mask_head.mask_upsample_stride > 1 and (stage == self.num_stages - 1\n                                                   or self.training):\n            scaled_mask_preds = F.interpolate(\n                mask_preds,\n                scale_factor=mask_head.mask_upsample_stride,\n                align_corners=False,\n                mode='bilinear')\n        else:\n            scaled_mask_preds = mask_preds\n        mask_results = dict(\n            cls_score=cls_score,\n            mask_preds=mask_preds,\n            scaled_mask_preds=scaled_mask_preds,\n            object_feats=object_feats,\n            
object_feats_track=object_feats_track,\n            x_feats=x_feats,\n        )\n\n        return mask_results\n\n    def forward_train(self,\n                      x,\n                      proposal_feats,\n                      mask_preds,\n                      cls_score,\n                      img_metas,\n                      gt_masks,\n                      gt_labels,\n                      gt_pids=None,\n                      gt_bboxes_ignore=None,\n                      imgs_whwh=None,\n                      gt_bboxes=None,\n                      gt_sem_seg=None,\n                      gt_sem_cls=None):\n\n        num_imgs = len(img_metas)\n        if self.mask_head[0].mask_upsample_stride > 1:\n            prev_mask_preds = F.interpolate(\n                mask_preds.detach(),\n                scale_factor=self.mask_head[0].mask_upsample_stride,\n                mode='bilinear',\n                align_corners=False)\n        else:\n            prev_mask_preds = mask_preds.detach()\n\n        if cls_score is not None:\n            prev_cls_score = cls_score.detach()\n        else:\n            prev_cls_score = [None] * num_imgs\n\n        if self.hard_target:\n            gt_masks = [x.bool().float() for x in gt_masks]\n        else:\n            gt_masks = gt_masks\n\n        object_feats = proposal_feats\n        all_stage_loss = {}\n        all_stage_mask_results = []\n        assign_results = []\n        final_sample_results = []\n        for stage in range(self.num_stages):\n            mask_results = self._mask_forward(stage, x, object_feats,\n                                              mask_preds, img_metas)\n            all_stage_mask_results.append(mask_results)\n            mask_preds = mask_results['mask_preds']\n            scaled_mask_preds = mask_results['scaled_mask_preds']\n            cls_score = mask_results['cls_score']\n            object_feats = mask_results['object_feats']\n            object_feats_track = 
mask_results['object_feats_track']\n\n            if self.post_assign:\n                prev_mask_preds = scaled_mask_preds.detach()\n                prev_cls_score = cls_score.detach()\n\n            sampling_results = []\n            if stage < self.assign_stages:\n                assign_results = []\n            for i in range(num_imgs):\n                if stage < self.assign_stages:\n                    mask_for_assign = prev_mask_preds[i][:self.num_proposals]\n                    if prev_cls_score[i] is not None:\n                        cls_for_assign = prev_cls_score[\n                            i][:self.num_proposals, :self.num_thing_classes]\n                    else:\n                        cls_for_assign = None\n\n                    assign_result = self.mask_assigner[stage].assign(\n                        mask_for_assign, cls_for_assign, gt_masks[i],\n                        gt_labels[i], img_meta=img_metas[i])\n                    assign_results.append(assign_result)\n                sampling_result = self.mask_sampler[stage].sample(\n                    assign_results[i], scaled_mask_preds[i], gt_masks[i])\n                sampling_results.append(sampling_result)\n\n            mask_targets = self.mask_head[stage].get_targets(\n                sampling_results,\n                gt_masks,\n                gt_labels,\n                self.train_cfg[stage],\n                True,\n                gt_sem_seg=gt_sem_seg,\n                gt_sem_cls=gt_sem_cls)\n\n            single_stage_loss = self.mask_head[stage].loss(\n                object_feats,\n                cls_score,\n                scaled_mask_preds,\n                *mask_targets,\n                imgs_whwh=imgs_whwh)\n            for key, value in single_stage_loss.items():\n                all_stage_loss[f's{stage}_{key}'] = value * \\\n                                    self.stage_loss_weights[stage]\n\n            if not self.post_assign:\n                prev_mask_preds = 
scaled_mask_preds.detach()\n                prev_cls_score = cls_score.detach()\n\n            if stage == self.num_stages - 1:\n                final_sample_results.extend(sampling_results)\n\n        if self.with_track:\n            return all_stage_loss, object_feats, cls_score, mask_preds, scaled_mask_preds\n        else:\n            return all_stage_loss\n\n    def forward_train_with_previous(self,\n                                      x,\n                                      proposal_feats,\n                                      mask_preds,\n                                      cls_score,\n                                      img_metas,\n                                      gt_masks,\n                                      gt_labels,\n                                      gt_pids=None,\n                                      gt_bboxes_ignore=None,\n                                      imgs_whwh=None,\n                                      gt_bboxes=None,\n                                      gt_sem_seg=None,\n                                      gt_sem_cls=None,\n                                      previous_obj_feats=None,\n                                      previous_mask_preds=None,\n                                      previous_x_feats=None,\n                                    ):\n\n        num_imgs = len(img_metas)\n        if self.mask_head[0].mask_upsample_stride > 1:\n            prev_mask_preds = F.interpolate(\n                mask_preds.detach(),\n                scale_factor=self.mask_head[0].mask_upsample_stride,\n                mode='bilinear',\n                align_corners=False)\n        else:\n            prev_mask_preds = mask_preds.detach()\n\n        if cls_score is not None:\n            prev_cls_score = cls_score.detach()\n        else:\n            prev_cls_score = [None] * num_imgs\n\n        if self.hard_target:\n            gt_masks = [x.bool().float() for x in gt_masks]\n        else:\n            gt_masks = 
gt_masks\n\n        object_feats = proposal_feats\n        all_stage_loss = {}\n        all_stage_mask_results = []\n        assign_results = []\n        final_sample_results = []\n        for stage in range(self.num_stages):\n\n            # only link the last stage\n            previous_obj_feats_cur = previous_obj_feats if stage == self.num_stages - 1 else None\n            previous_mask_preds_cur = previous_mask_preds if stage == self.num_stages - 1 else None\n            previous_x_feats_cur = previous_x_feats if stage == self.num_stages - 1 else None\n\n            # only link the first stage\n            # previous_obj_feats_cur = previous_obj_feats if stage == 0 else None\n            # previous_mask_preds_cur = previous_mask_preds if stage == 0 else None\n            # previous_x_feats_cur = previous_x_feats if stage == 0 else None\n\n            mask_results = self._mask_forward(stage, x, object_feats,\n                                              mask_preds, img_metas,\n                                              previous_obj_feats=previous_obj_feats_cur,\n                                              previous_mask_preds=previous_mask_preds_cur,\n                                              previous_x_feats=previous_x_feats_cur)\n            all_stage_mask_results.append(mask_results)\n            mask_preds = mask_results['mask_preds']\n            scaled_mask_preds = mask_results['scaled_mask_preds']\n            cls_score = mask_results['cls_score']\n            object_feats = mask_results['object_feats']\n            object_feats_track = mask_results['object_feats_track']\n\n            if self.post_assign:\n                prev_mask_preds = scaled_mask_preds.detach()\n                prev_cls_score = cls_score.detach()\n\n            sampling_results = []\n            if stage < self.assign_stages:\n                assign_results = []\n            for i in range(num_imgs):\n                if stage < self.assign_stages:\n                    
mask_for_assign = prev_mask_preds[i][:self.num_proposals]\n                    if prev_cls_score[i] is not None:\n                        cls_for_assign = prev_cls_score[\n                            i][:self.num_proposals, :self.num_thing_classes]\n                    else:\n                        cls_for_assign = None\n\n                    assign_result = self.mask_assigner[stage].assign(\n                        mask_for_assign, cls_for_assign, gt_masks[i],\n                        gt_labels[i], img_meta=img_metas[i])\n                    assign_results.append(assign_result)\n                sampling_result = self.mask_sampler[stage].sample(\n                    assign_results[i], scaled_mask_preds[i], gt_masks[i])\n                sampling_results.append(sampling_result)\n\n            mask_targets = self.mask_head[stage].get_targets(\n                sampling_results,\n                gt_masks,\n                gt_labels,\n                self.train_cfg[stage],\n                True,\n                gt_sem_seg=gt_sem_seg,\n                gt_sem_cls=gt_sem_cls)\n\n            single_stage_loss = self.mask_head[stage].loss(\n                object_feats,\n                cls_score,\n                scaled_mask_preds,\n                *mask_targets,\n                imgs_whwh=imgs_whwh)\n            for key, value in single_stage_loss.items():\n                all_stage_loss[f's{stage}_{key}'] = value * \\\n                                    self.stage_loss_weights[stage]\n\n            if not self.post_assign:\n                prev_mask_preds = scaled_mask_preds.detach()\n                prev_cls_score = cls_score.detach()\n\n            if stage == self.num_stages - 1:\n                final_sample_results.extend(sampling_results)\n\n        if self.with_track:\n            return all_stage_loss, object_feats, cls_score, mask_preds, scaled_mask_preds, object_feats_track\n        else:\n            return all_stage_loss\n\n    def simple_test(self,\n        
            x,\n                    proposal_feats,\n                    mask_preds,\n                    cls_score,\n                    img_metas):\n\n        # Decode initial proposals\n        num_imgs = len(img_metas)\n        # num_proposals = proposal_feats.size(1)\n\n        object_feats = proposal_feats\n        for stage in range(self.num_stages):\n            mask_results = self._mask_forward(stage, x, object_feats,\n                                              mask_preds, img_metas)\n            object_feats = mask_results['object_feats']\n            cls_score = mask_results['cls_score']\n            mask_preds = mask_results['mask_preds']\n            scaled_mask_preds = mask_results['scaled_mask_preds']\n            object_feats_track = mask_results['object_feats_track']\n\n        num_classes = self.mask_head[-1].num_classes\n        results = []\n\n        if self.mask_head[-1].loss_cls.use_sigmoid:\n            cls_score = cls_score.sigmoid()\n        else:\n            cls_score = cls_score.softmax(-1)[..., :-1]\n\n        if self.do_panoptic:\n            for img_id in range(num_imgs):\n                single_result = self.get_panoptic(cls_score[img_id],\n                                                  scaled_mask_preds[img_id],\n                                                  self.test_cfg,\n                                                  img_metas[img_id],\n                                                  object_feats[img_id]\n                                                  )\n                results.append(single_result)\n        else:\n            for img_id in range(num_imgs):\n                cls_score_per_img = cls_score[img_id]\n                scores_per_img, topk_indices = cls_score_per_img.flatten(\n                    0, 1).topk(\n                        self.test_cfg.max_per_img, sorted=True)\n                mask_indices = topk_indices // num_classes\n                labels_per_img = topk_indices % num_classes\n           
     masks_per_img = scaled_mask_preds[img_id][mask_indices]\n                single_result = self.mask_head[-1].get_seg_masks(\n                    masks_per_img, labels_per_img, scores_per_img,\n                    self.test_cfg, img_metas[img_id])\n                results.append(single_result)\n\n        if self.with_track:\n            return results, object_feats, cls_score, mask_preds, scaled_mask_preds\n        else:\n            return results\n\n    def simple_test_with_previous(self,\n                                    x,\n                                    proposal_feats,\n                                    mask_preds,\n                                    cls_score,\n                                    img_metas,\n                                  previous_obj_feats=None,\n                                  previous_mask_preds=None,\n                                  previous_x_feats=None,\n                                  is_first=False,\n                                  ):\n\n        # Decode initial proposals\n        num_imgs = len(img_metas)\n        # num_proposals = proposal_feats.size(1)\n\n        object_feats = proposal_feats\n        for stage in range(self.num_stages):\n            # only link the last stage inputs\n            previous_obj_feats_cur = previous_obj_feats if stage == self.num_stages - 1 else None\n            previous_mask_preds_cur = previous_mask_preds if stage == self.num_stages - 1 else None\n            previous_x_feats_cur = previous_x_feats if stage == self.num_stages - 1 else None\n\n            mask_results = self._mask_forward(stage, x, object_feats,\n                                              mask_preds, img_metas,\n                                              previous_obj_feats=previous_obj_feats_cur,\n                                              previous_mask_preds=previous_mask_preds_cur,\n                                              previous_x_feats=previous_x_feats_cur\n                                
              )\n            object_feats = mask_results['object_feats']\n            cls_score = mask_results['cls_score']\n            mask_preds = mask_results['mask_preds']\n            scaled_mask_preds = mask_results['scaled_mask_preds']\n            object_feats_track = mask_results['object_feats_track']\n\n        num_classes = self.mask_head[-1].num_classes\n        results = []\n\n        if self.mask_head[-1].loss_cls.use_sigmoid:\n            cls_score = cls_score.sigmoid()\n        else:\n            cls_score = cls_score.softmax(-1)[..., :-1]\n\n        if is_first:\n            object_feats_track = object_feats\n\n        if self.do_panoptic:\n            for img_id in range(num_imgs):\n                single_result = self.get_panoptic(cls_score[img_id],\n                                                  scaled_mask_preds[img_id],\n                                                  self.test_cfg,\n                                                  img_metas[img_id],\n                                                  object_feats_track[img_id])\n                results.append(single_result)\n        else:\n            for img_id in range(num_imgs):\n                cls_score_per_img = cls_score[img_id]\n                scores_per_img, topk_indices = cls_score_per_img.flatten(\n                    0, 1).topk(\n                        self.test_cfg.max_per_img, sorted=True)\n                mask_indices = topk_indices // num_classes\n                labels_per_img = topk_indices % num_classes\n                masks_per_img = scaled_mask_preds[img_id][mask_indices]\n                single_result = self.mask_head[-1].get_seg_masks(\n                    masks_per_img, labels_per_img, scores_per_img,\n                    self.test_cfg, img_metas[img_id])\n                results.append(single_result)\n\n        if self.with_track:\n            return results, object_feats, cls_score, mask_preds, scaled_mask_preds\n        else:\n            return results\n\n 
   def simple_test_mask_preds(self,\n                    x,\n                    proposal_feats,\n                    mask_preds,\n                    cls_score,\n                    img_metas):\n\n        object_feats = proposal_feats\n        for stage in range(self.num_stages):\n            mask_results = self._mask_forward(stage, x, object_feats,\n                                              mask_preds, img_metas)\n            object_feats = mask_results['object_feats']\n            cls_score = mask_results['cls_score']\n            mask_preds = mask_results['mask_preds']\n            scaled_mask_preds = mask_results['scaled_mask_preds']\n\n        if self.mask_head[-1].loss_cls.use_sigmoid:\n            cls_score = cls_score.sigmoid()\n        else:\n            cls_score = cls_score.softmax(-1)[..., :-1]\n\n        return object_feats, cls_score, mask_preds, scaled_mask_preds\n\n    def simple_test_mask_preds_plus_previous(\n            self,\n            x,\n            proposal_feats,\n            mask_preds,\n            cls_score,\n            img_metas,\n            previous_obj_feats=None,\n            previous_mask_preds=None,\n            previous_x_feats=None,\n        ):\n\n        object_feats = proposal_feats\n        for stage in range(self.num_stages):\n            previous_obj_feats_cur = previous_obj_feats if stage == self.num_stages - 1 else None\n            previous_mask_preds_cur = previous_mask_preds if stage == self.num_stages - 1 else None\n            previous_x_feats_cur = previous_x_feats if stage == self.num_stages - 1 else None\n            mask_results = self._mask_forward(stage, x, object_feats,\n                                              mask_preds, img_metas,\n                                              previous_obj_feats=previous_obj_feats_cur,\n                                              previous_mask_preds=previous_mask_preds_cur,\n                                              previous_x_feats=previous_x_feats_cur\n  
                                            )\n            object_feats = mask_results['object_feats']\n            cls_score = mask_results['cls_score']\n            mask_preds = mask_results['mask_preds']\n            scaled_mask_preds = mask_results['scaled_mask_preds']\n\n        if self.mask_head[-1].loss_cls.use_sigmoid:\n            cls_score = cls_score.sigmoid()\n        else:\n            cls_score = cls_score.softmax(-1)[..., :-1]\n\n        return object_feats, cls_score, mask_preds, scaled_mask_preds\n\n    def get_masked_feature(self, x, mask_pred):\n        sigmoid_masks = mask_pred.sigmoid()\n        nonzero_inds = sigmoid_masks > 0.5\n        sigmoid_masks = nonzero_inds.float()\n        x_feat = torch.einsum('bnhw,bchw->bnc', sigmoid_masks, x)\n        return x_feat\n\n    def aug_test(self, features, proposal_list, img_metas, rescale=False):\n        raise NotImplementedError('SparseMask does not support `aug_test`')\n\n    def forward_dummy(self, x, proposal_boxes, proposal_feats, img_metas):\n        \"\"\"Dummy forward function when do the flops computing.\"\"\"\n        all_stage_mask_results = []\n        num_imgs = len(img_metas)\n        num_proposals = proposal_feats.size(1)\n        C, H, W = x.shape[-3:]\n        mask_preds = proposal_feats.bmm(x.view(num_imgs, C, -1)).view(\n            num_imgs, num_proposals, H, W)\n        object_feats = proposal_feats\n        for stage in range(self.num_stages):\n            mask_results = self._mask_forward(stage, x, object_feats,\n                                              mask_preds, img_metas)\n            all_stage_mask_results.append(mask_results)\n        return all_stage_mask_results\n\n    def get_panoptic(self, cls_scores, mask_preds, test_cfg, img_meta, obj_feat=None):\n        # resize mask predictions back\n        thing_scores = cls_scores[:self.num_proposals][:, :self.\n                                                       num_thing_classes]\n        thing_mask_preds = 
mask_preds[:self.num_proposals]\n        thing_scores, topk_indices = thing_scores.flatten(0, 1).topk(\n            self.test_cfg.max_per_img, sorted=True)\n        mask_indices = topk_indices // self.num_thing_classes\n        thing_labels = topk_indices % self.num_thing_classes\n        masks_per_img = thing_mask_preds[mask_indices]\n        thing_masks = self.mask_head[-1].rescale_masks(masks_per_img, img_meta)\n\n        # thing obj_feat\n        thing_obj_feat = obj_feat[:self.num_proposals]\n        thing_obj_feat = thing_obj_feat[mask_indices]\n\n        if not self.merge_joint:\n            thing_masks = thing_masks > test_cfg.mask_thr\n        bbox_result, segm_result, thing_mask_preds = self.mask_head[-1].segm2result(\n            thing_masks, thing_labels, thing_scores)\n\n        stuff_scores = cls_scores[\n            self.num_proposals:][:, self.num_thing_classes:].diag()\n        stuff_scores, stuff_inds = torch.sort(stuff_scores, descending=True)\n        stuff_masks = mask_preds[self.num_proposals:][stuff_inds]\n        stuff_masks = self.mask_head[-1].rescale_masks(stuff_masks, img_meta)\n\n        # stuff obj_feat\n        stuff_obj_feat = obj_feat[self.num_proposals:][stuff_inds]\n\n        if not self.merge_joint:\n            stuff_masks = stuff_masks > test_cfg.mask_thr\n\n        if self.merge_joint:\n            stuff_labels = stuff_inds + self.num_thing_classes\n            panoptic_result, thing_obj_feat = self.merge_stuff_thing_stuff_joint(thing_masks, thing_labels,\n                                                                 thing_scores, stuff_masks,\n                                                                 stuff_labels, stuff_scores,\n                                                                 test_cfg.merge_stuff_thing,\n                                                                 thing_obj_feat, stuff_obj_feat\n                                                                 )\n        else:\n            
stuff_labels = stuff_inds + 1\n            panoptic_result, thing_obj_feat = self.merge_stuff_thing_thing_first(\n                thing_masks, thing_labels, thing_scores, stuff_masks,\n                stuff_labels, stuff_scores, test_cfg.merge_stuff_thing,\n                thing_obj_feat, stuff_obj_feat)\n\n        return bbox_result, segm_result, thing_mask_preds, panoptic_result, thing_obj_feat\n\n    def split_thing_stuff(self, mask_preds, det_labels, cls_scores):\n        thing_scores = cls_scores[:self.num_proposals]\n        thing_masks = mask_preds[:self.num_proposals]\n        thing_labels = det_labels[:self.num_proposals]\n\n        stuff_labels = det_labels[self.num_proposals:]\n        stuff_labels = stuff_labels - self.num_thing_classes + 1\n        stuff_masks = mask_preds[self.num_proposals:]\n        stuff_scores = cls_scores[self.num_proposals:]\n\n        results = (thing_masks, thing_labels, thing_scores, stuff_masks,\n                   stuff_labels, stuff_scores)\n        return results\n\n    def merge_stuff_thing_thing_first(self,\n                          thing_masks,\n                          thing_labels,\n                          thing_scores,\n                          stuff_masks,\n                          stuff_labels,\n                          stuff_scores,\n                          merge_cfg=None,\n                          thing_obj_feat=None,\n                          stuff_obj_feat=None):\n\n        H, W = thing_masks.shape[-2:]\n        panoptic_seg = thing_masks.new_zeros((H, W), dtype=torch.int32)\n        thing_masks = thing_masks.to(\n            dtype=torch.bool, device=panoptic_seg.device)\n        stuff_masks = stuff_masks.to(\n            dtype=torch.bool, device=panoptic_seg.device)\n\n        # sort instance outputs by scores\n        sorted_inds = 
torch.argsort(-thing_scores)\n        thing_obj_feat = thing_obj_feat[sorted_inds]\n        current_segment_id = 0\n        segments_info = []\n        instance_ids = []\n\n        # Add instances one-by-one, check for overlaps with existing ones\n        for inst_id in sorted_inds:\n            score = thing_scores[inst_id].item()\n            if score < merge_cfg.instance_score_thr:\n                break\n            mask = thing_masks[inst_id]  # H,W\n            mask_area = mask.sum().item()\n\n            if mask_area == 0:\n                continue\n\n            intersect = (mask > 0) & (panoptic_seg > 0)\n            intersect_area = intersect.sum().item()\n\n            if intersect_area * 1.0 / mask_area > merge_cfg.iou_thr:\n                continue\n\n            if intersect_area > 0:\n                mask = mask & (panoptic_seg == 0)\n\n            mask_area = mask.sum().item()\n            if mask_area == 0:\n                continue\n\n            current_segment_id += 1\n            panoptic_seg[mask.bool()] = current_segment_id\n            segments_info.append({\n                'id': current_segment_id,\n                'isthing': True,\n                'score': score,\n                'category_id': thing_labels[inst_id].item(),\n                'instance_id': inst_id.item(),\n            })\n            instance_ids.append(inst_id.item())\n\n        # Add semantic results to remaining empty areas\n        sorted_inds = torch.argsort(-stuff_scores)\n        sorted_stuff_labels = stuff_labels[sorted_inds]\n        # paste semantic masks following the order of scores\n        processed_label = []\n        for semantic_label in sorted_stuff_labels:\n            semantic_label = semantic_label.item()\n            if semantic_label in processed_label:\n                continue\n            processed_label.append(semantic_label)\n            sem_inds = stuff_labels == semantic_label\n            sem_masks = stuff_masks[sem_inds].sum(0).bool()\n      
      mask = sem_masks & (panoptic_seg == 0)\n            mask_area = mask.sum().item()\n            if mask_area < merge_cfg.stuff_max_area:\n                continue\n\n            current_segment_id += 1\n            panoptic_seg[mask] = current_segment_id\n            segments_info.append({\n                'id': current_segment_id,\n                'isthing': False,\n                'category_id': semantic_label,\n                'area': mask_area,\n            })\n        return (panoptic_seg.cpu().numpy(), segments_info), thing_obj_feat[instance_ids]\n\n    def merge_stuff_thing_stuff_first(self,\n                          thing_masks,\n                          thing_labels,\n                          thing_scores,\n                          stuff_masks,\n                          stuff_labels,\n                          stuff_scores,\n                          merge_cfg=None,\n                          thing_obj_feat=None,\n                          stuff_obj_feat=None):\n\n        H, W = thing_masks.shape[-2:]\n        panoptic_seg = thing_masks.new_zeros((H, W), dtype=torch.int32)\n        thing_masks = thing_masks.to(\n            dtype=torch.bool, device=panoptic_seg.device)\n        stuff_masks = stuff_masks.to(\n            dtype=torch.bool, device=panoptic_seg.device)\n\n        current_segment_id = 0\n        segments_info = []\n\n        # Add semantic results first\n        sorted_inds = torch.argsort(-stuff_scores)\n        sorted_stuff_labels = stuff_labels[sorted_inds]\n        # paste semantic masks following the order of scores\n        processed_label = []\n        for semantic_label in sorted_stuff_labels:\n            semantic_label = semantic_label.item()\n            if semantic_label in processed_label:\n                continue\n            processed_label.append(semantic_label)\n            sem_inds = stuff_labels == semantic_label\n            sem_masks = stuff_masks[sem_inds].sum(0).bool()\n            mask = sem_masks & 
(panoptic_seg == 0)\n            mask_area = mask.sum().item()\n            if mask_area < merge_cfg.stuff_max_area:\n                continue\n\n            current_segment_id += 1\n            panoptic_seg[mask] = current_segment_id\n            segments_info.append({\n                'id': current_segment_id,\n                'isthing': False,\n                'category_id': semantic_label,\n                'area': mask_area,\n            })\n\n        # sort instance outputs by scores\n        sorted_inds = torch.argsort(-thing_scores)\n        # thing obj feat\n        thing_obj_feat = thing_obj_feat[sorted_inds]\n        # Add instances one-by-one, check for overlaps with existing ones\n        instance_ids = []\n        for inst_id in sorted_inds:\n            score = thing_scores[inst_id].item()\n            if score < merge_cfg.instance_score_thr:\n                break\n            mask = thing_masks[inst_id]  # H,W\n            mask_area = mask.sum().item()\n\n            if mask_area == 0:\n                continue\n\n            intersect = (mask > 0) & (panoptic_seg > 0)\n            intersect_area = intersect.sum().item()\n\n            if intersect_area * 1.0 / mask_area > merge_cfg.iou_thr:\n                continue\n\n            if intersect_area > 0:\n                mask = mask & (panoptic_seg == 0)\n\n            mask_area = mask.sum().item()\n            if mask_area == 0:\n                continue\n\n            current_segment_id += 1\n            panoptic_seg[mask.bool()] = current_segment_id\n            segments_info.append({\n                'id': current_segment_id,\n                'isthing': True,\n                'score': score,\n                'category_id': thing_labels[inst_id].item(),\n                'instance_id': inst_id.item(),\n            })\n            instance_ids.append(inst_id.item())\n\n        return (panoptic_seg.cpu().numpy(), segments_info), thing_obj_feat[instance_ids]\n\n    def 
merge_stuff_thing_stuff_joint(self,\n                                      thing_masks,\n                                      thing_labels,\n                                      thing_scores,\n                                      stuff_masks,\n                                      stuff_labels,\n                                      stuff_scores,\n                                      merge_cfg=None,\n                                      thing_obj=None,\n                                      stuff_obj=None\n                                      ):\n\n        H, W = thing_masks.shape[-2:]\n        panoptic_seg = thing_masks.new_zeros((H, W), dtype=torch.int32)\n\n        total_masks = torch.cat([thing_masks, stuff_masks], dim=0)\n        total_scores = torch.cat([thing_scores, stuff_scores], dim=0)\n        total_labels = torch.cat([thing_labels, stuff_labels], dim=0)\n        obj_fea = torch.cat([thing_obj, stuff_obj], dim=0)\n\n        cur_prob_masks = total_scores.view(-1, 1, 1) * total_masks\n        segments_info = []\n        cur_mask_ids = cur_prob_masks.argmax(0)\n\n        # sort instance outputs by scores\n        sorted_inds = torch.argsort(-total_scores)\n        current_segment_id = 0\n        sort_obj_fea = obj_fea\n        things_ids = []\n        for k in sorted_inds:\n            pred_class = total_labels[k].item()\n            isthing = pred_class < self.num_thing_classes\n            if isthing and total_scores[k] < merge_cfg.instance_score_thr:\n                continue\n\n            mask = cur_mask_ids == k\n            mask_area = mask.sum().item()\n            original_area = (total_masks[k] >= 0.5).sum().item()\n\n            if mask_area > 0 and original_area > 0:\n                if mask_area / original_area < merge_cfg.overlap_thr:\n                    continue\n                current_segment_id += 1\n\n                panoptic_seg[mask] = current_segment_id\n\n                if isthing:\n                    segments_info.append({\n 
                       'id': current_segment_id,\n                        'isthing': isthing,\n                        'score': total_scores[k].item(),\n                        'category_id': pred_class,  # 0, num_thing - 1\n                        'instance_id': k.item(),\n                    })\n                    things_ids.append(k.item())\n                else:\n                    segments_info.append({\n                        'id': current_segment_id,\n                        'isthing': isthing,\n                        'category_id': pred_class - self.num_thing_classes + 1, # 1, num_stuff\n                        'area': mask_area,\n                    })\n\n        return (panoptic_seg.cpu().numpy(), segments_info), sort_obj_fea[things_ids]"
  },
  {
    "path": "knet/video/kernel_update_head.py",
    "content": "import numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom mmcv.cnn import (ConvModule, bias_init_with_prob,\n                      build_activation_layer, build_norm_layer)\nfrom mmcv.runner import force_fp32\nfrom mmdet.core import multi_apply\nfrom mmdet.models.builder import HEADS, build_loss\nfrom mmdet.models.dense_heads.atss_head import reduce_mean\nfrom mmdet.models.losses import accuracy\nfrom mmcv.cnn.bricks.transformer import FFN, MultiheadAttention, build_transformer_layer\nfrom mmdet.utils import get_root_logger\nfrom unitrack.mask import mask2box, tensor_mask2box\n\n\n@HEADS.register_module()\nclass VideoKernelUpdateHead(nn.Module):\n\n    def __init__(self,\n                 num_classes=80,\n                 num_ffn_fcs=2,\n                 num_heads=8,\n                 num_cls_fcs=1,\n                 num_mask_fcs=3,\n                 feedforward_channels=2048,\n                 in_channels=256,\n                 out_channels=256,\n                 dropout=0.0,\n                 mask_thr=0.5,\n                 act_cfg=dict(type='ReLU', inplace=True),\n                 ffn_act_cfg=dict(type='ReLU', inplace=True),\n                 conv_kernel_size=3,\n                 feat_transform_cfg=None,\n                 hard_mask_thr=0.5,\n                 kernel_init=False,\n                 with_ffn=True,\n                 mask_out_stride=4,\n                 relative_coors=False,\n                 relative_coors_off=False,\n                 feat_gather_stride=1,\n                 mask_transform_stride=1,\n                 mask_upsample_stride=1,\n                 num_thing_classes=80,\n                 num_stuff_classes=53,\n                 mask_assign_stride=4,\n                 ignore_label=255,\n                 thing_label_in_seg=0,\n                 previous=None,\n                 previous_x_feat=None,\n                 previous_link=None,  # seg/cls embeddings\n                 previous_type=None, 
 # tracking embeddings\n                 previous_detach=False,\n                 previous_detach_link=False,  # whether to detach the link query\n                 previous_link_detach=False,\n                 kernel_updator_cfg=dict(\n                     type='DynamicConv',\n                     in_channels=256,\n                     feat_channels=64,\n                     out_channels=256,\n                     input_feat_shape=1,\n                     act_cfg=dict(type='ReLU', inplace=True),\n                     norm_cfg=dict(type='LN')),\n                 loss_rank=None,\n                 loss_mask=dict(\n                     type='CrossEntropyLoss', use_mask=True, loss_weight=1.0),\n                 loss_dice=dict(type='DiceLoss', loss_weight=3.0),\n                 loss_cls=dict(\n                     type='FocalLoss',\n                     use_sigmoid=True,\n                     gamma=2.0,\n                     alpha=0.25,\n                     loss_weight=2.0)):\n        super(VideoKernelUpdateHead, self).__init__()\n        self.num_classes = num_classes\n        self.loss_cls = build_loss(loss_cls)\n        self.loss_mask = build_loss(loss_mask)\n        self.loss_dice = build_loss(loss_dice)\n        if loss_rank is not None:\n            self.loss_rank = build_loss(loss_rank)\n        else:\n            self.loss_rank = loss_rank\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.mask_thr = mask_thr\n        self.fp16_enabled = False\n        self.dropout = dropout\n\n        self.num_heads = num_heads\n        self.hard_mask_thr = hard_mask_thr\n        self.kernel_init = kernel_init\n        self.with_ffn = with_ffn\n        self.mask_out_stride = mask_out_stride\n        self.relative_coors = relative_coors\n        self.relative_coors_off = relative_coors_off\n        self.conv_kernel_size = conv_kernel_size\n        self.feat_gather_stride = feat_gather_stride\n        self.mask_transform_stride = 
mask_transform_stride\n        self.mask_upsample_stride = mask_upsample_stride\n\n        self.num_thing_classes = num_thing_classes\n        self.num_stuff_classes = num_stuff_classes\n        self.mask_assign_stride = mask_assign_stride\n        self.ignore_label = ignore_label\n        self.thing_label_in_seg = thing_label_in_seg\n\n        self.attention = MultiheadAttention(\n            in_channels * conv_kernel_size ** 2, num_heads, dropout)\n        self.attention_norm = build_norm_layer(\n            dict(type='LN'), in_channels * conv_kernel_size ** 2)[1]\n\n        self.kernel_update_conv = build_transformer_layer(kernel_updator_cfg)\n\n        if feat_transform_cfg is not None:\n            kernel_size = feat_transform_cfg.pop('kernel_size', 1)\n            self.feat_transform = ConvModule(\n                in_channels,\n                in_channels,\n                kernel_size,\n                stride=feat_gather_stride,\n                padding=int(feat_gather_stride // 2),\n                **feat_transform_cfg)\n        else:\n            self.feat_transform = None\n\n        if self.with_ffn:\n            self.ffn = FFN(\n                in_channels,\n                feedforward_channels,\n                num_ffn_fcs,\n                act_cfg=ffn_act_cfg,\n                dropout=dropout)\n            self.ffn_norm = build_norm_layer(dict(type='LN'), in_channels)[1]\n\n        self.cls_fcs = nn.ModuleList()\n        for _ in range(num_cls_fcs):\n            self.cls_fcs.append(\n                nn.Linear(in_channels, in_channels, bias=False))\n            self.cls_fcs.append(\n                build_norm_layer(dict(type='LN'), in_channels)[1])\n            self.cls_fcs.append(build_activation_layer(act_cfg))\n\n        if self.loss_cls.use_sigmoid:\n            self.fc_cls = nn.Linear(in_channels, self.num_classes)\n        else:\n            self.fc_cls = nn.Linear(in_channels, self.num_classes + 1)\n\n        self.mask_fcs = nn.ModuleList()\n      
  for _ in range(num_mask_fcs):\n            self.mask_fcs.append(\n                nn.Linear(in_channels, in_channels, bias=False))\n            self.mask_fcs.append(\n                build_norm_layer(dict(type='LN'), in_channels)[1])\n            self.mask_fcs.append(build_activation_layer(act_cfg))\n\n        self.fc_mask = nn.Linear(in_channels, out_channels)\n\n        self.previous = previous\n        self.previous_type = previous_type\n        self.previous_link = previous_link\n        self.previous_x_feat = previous_x_feat\n        self.previous_detach = previous_detach\n        self.previous_detach_link = previous_detach_link\n        self.previous_link_detach = previous_link_detach\n\n        if self.previous is not None:\n            _in_channels = self.in_channels\n            _conv_kernel_size = self.conv_kernel_size\n            _num_head = 8\n            _dropout = 0.\n            # tracking embedding\n            if self.previous_type == \"ffn\":\n                self.attention_previous = MultiheadAttention(\n                    _in_channels * _conv_kernel_size ** 2,\n                    _num_head,\n                    _dropout,\n                )\n                _, self.attention_previous_norm = build_norm_layer(\n                    dict(type='LN'),\n                    _in_channels * _conv_kernel_size ** 2\n                )\n                # add link ffn\n                self.link_ffn = FFN(\n                    in_channels,\n                    feedforward_channels,\n                    num_ffn_fcs,\n                    act_cfg=ffn_act_cfg,\n                    dropout=dropout)\n                self.link_ffn_norm = build_norm_layer(dict(type='LN'), in_channels)[1]\n\n            elif self.previous_type == \"update\" or self.previous_type == \"update_obj\":\n\n                self.attention_previous_update_track = build_transformer_layer(kernel_updator_cfg)\n\n                self.attention_previous_track = MultiheadAttention(\n               
     _in_channels * _conv_kernel_size ** 2,\n                    _num_head,\n                    _dropout,\n                )\n                _, self.attention_previous_norm_track = build_norm_layer(\n                    dict(type='LN'),\n                    _in_channels * _conv_kernel_size ** 2\n                )\n                # add link ffn\n                self.link_ffn_track = FFN(\n                    in_channels,\n                    feedforward_channels,\n                    num_ffn_fcs,\n                    act_cfg=ffn_act_cfg,\n                    dropout=dropout)\n                self.link_ffn_norm_track = build_norm_layer(dict(type='LN'), in_channels)[1]\n\n            # seg and cls embedding Link\n            if self.previous_link == \"update_dynamic_cov\":\n                _in_channels = self.in_channels\n                _conv_kernel_size = self.conv_kernel_size\n                _num_head = 8\n                _dropout = 0.\n                self.attention_previous_update_link = build_transformer_layer(kernel_updator_cfg)\n                self.attention_previous_link = MultiheadAttention(\n                    _in_channels * _conv_kernel_size ** 2,\n                    _num_head,\n                    _dropout,\n                )\n                _, self.attention_previous_norm_link = build_norm_layer(\n                    dict(type='LN'),\n                    _in_channels * _conv_kernel_size ** 2\n                )\n                # add link ffn\n                self.link_ffn_link = FFN(\n                    in_channels,\n                    feedforward_channels,\n                    num_ffn_fcs,\n                    act_cfg=ffn_act_cfg,\n                    dropout=dropout)\n                self.link_ffn_norm_link = build_norm_layer(dict(type='LN'), in_channels)[1]\n\n            elif self.previous_link == \"link_atten\":\n                _in_channels = self.in_channels\n                _conv_kernel_size = self.conv_kernel_size\n                
_num_head = 8\n                _dropout = 0.\n                self.attention_previous_link = MultiheadAttention(\n                    _in_channels * _conv_kernel_size ** 2,\n                    _num_head,\n                    _dropout,\n                )\n                _, self.attention_previous_norm_link = build_norm_layer(\n                    dict(type='LN'),\n                    _in_channels * _conv_kernel_size ** 2\n                )\n                # add link ffn\n                self.link_ffn_link = FFN(\n                    in_channels,\n                    feedforward_channels,\n                    num_ffn_fcs,\n                    act_cfg=ffn_act_cfg,\n                    dropout=dropout)\n                self.link_ffn_norm_link = build_norm_layer(dict(type='LN'), in_channels)[1]\n\n    def init_weights(self):\n        \"\"\"Use Xavier initialization for all weight parameters and, when a\n        sigmoid-based (focal) classification loss is used, initialize the\n        classification bias from a prior probability of 0.01.\"\"\"\n        for p in self.parameters():\n            if p.dim() > 1:\n                nn.init.xavier_uniform_(p)\n            else:\n                # adopt the default initialization for\n                # the weight and bias of the layer norm\n                pass\n        if self.loss_cls.use_sigmoid:\n            bias_init = bias_init_with_prob(0.01)\n            nn.init.constant_(self.fc_cls.bias, bias_init)\n        if self.kernel_init:\n            logger = get_root_logger()\n            logger.info(\n                'mask kernel in mask head is normally initialized with std 0.01')\n            nn.init.normal_(self.fc_mask.weight, mean=0, std=0.01)\n\n    def forward(self,\n                x,\n                proposal_feat,\n                mask_preds,\n                prev_cls_score=None,\n                mask_shape=None,\n                img_metas=None,\n                previous_obj_feats=None,\n                previous_mask_preds=None,\n                previous_x_feats=None\n                
):\n\n        N, num_proposals = proposal_feat.shape[:2]\n        if self.feat_transform is not None:\n            x = self.feat_transform(x)\n            if previous_x_feats is not None:\n                previous_x_feats = self.feat_transform(previous_x_feats)\n        C, H, W = x.shape[-3:]\n\n        mask_h, mask_w = mask_preds.shape[-2:]\n        if mask_h != H or mask_w != W:\n            gather_mask = F.interpolate(\n                mask_preds, (H, W), align_corners=False, mode='bilinear')\n        else:\n            gather_mask = mask_preds\n\n        sigmoid_masks = gather_mask.sigmoid()\n        nonzero_inds = sigmoid_masks > self.hard_mask_thr\n        sigmoid_masks = nonzero_inds.float()\n\n        # einsum is faster than bmm by 30%\n        x_feat = torch.einsum('bnhw,bchw->bnc', sigmoid_masks, x)\n\n        # obj_feat in shape [B, N, C, K, K] -> [B, N, C, K*K] -> [B, N, K*K, C]\n        proposal_feat = proposal_feat.reshape(N, num_proposals,\n                                              self.in_channels,\n                                              -1).permute(0, 1, 3, 2)\n\n        # whether to detach the previous outputs (skip when there is no previous frame)\n        if self.training and self.previous_detach and previous_obj_feats is not None:\n            previous_obj_feats = previous_obj_feats.detach()\n\n        # update previous with link object query\n        if previous_obj_feats is not None and self.previous_link == \"update_dynamic_cov\":\n            previous_obj_feats_link = previous_obj_feats.reshape(N, num_proposals,\n                                                                 self.in_channels,\n                                                                 -1).permute(0, 1, 3, 2)\n\n            if self.training and self.previous_detach_link:\n                previous_obj_feats_link = previous_obj_feats_link.detach()\n\n            previous_obj_feats_update = self.attention_previous_update_link(x_feat, previous_obj_feats_link)\n\n            previous_obj_feats_update = 
previous_obj_feats_update.reshape(N, num_proposals, -1).permute(1, 0, 2)\n            cur_obj_feat = proposal_feat.reshape(N, num_proposals, self.in_channels * self.conv_kernel_size ** 2). \\\n                permute(1, 0, 2)\n            cur_obj_feat = self.attention_previous_norm_link(\n                self.attention_previous_link(\n                    query=cur_obj_feat,\n                    key=previous_obj_feats_update,\n                    value=previous_obj_feats_update,\n                    identity=cur_obj_feat\n                ),\n            )\n            cur_obj_feat = cur_obj_feat.permute(1, 0, 2)\n            cur_obj_feat = cur_obj_feat.reshape(N, num_proposals, -1, self.in_channels)\n            # pre_obj_feat in shape [B, N, K*K*C] -> [B, N, K*K, C]\n            proposal_feat = self.link_ffn_norm_link(self.link_ffn_link(cur_obj_feat))\n\n        if previous_obj_feats is not None and self.previous_link == \"link_atten\":\n            previous_obj_feats_link = previous_obj_feats.reshape(N, num_proposals,\n                                                                 self.in_channels,\n                                                                 -1).permute(0, 1, 3, 2)\n\n            previous_obj_feats_update = previous_obj_feats_link.reshape(N, num_proposals, -1).permute(1, 0, 2)\n            cur_obj_feat = proposal_feat.reshape(N, num_proposals, self.in_channels * self.conv_kernel_size ** 2). 
\\\n                permute(1, 0, 2)\n            cur_obj_feat = self.attention_previous_norm_link(\n                self.attention_previous_link(\n                    query=cur_obj_feat,\n                    key=previous_obj_feats_update,\n                    value=previous_obj_feats_update,\n                    identity=cur_obj_feat\n                ),\n            )\n            cur_obj_feat = cur_obj_feat.permute(1, 0, 2)\n            cur_obj_feat = cur_obj_feat.reshape(N, num_proposals, -1, self.in_channels)\n            # pre_obj_feat in shape [B, N, K*K*C] -> [B, N, K*K, C]\n            proposal_feat = self.link_ffn_norm_link(self.link_ffn_link(cur_obj_feat))\n\n        # update current\n        obj_feat = self.kernel_update_conv(x_feat, proposal_feat)\n\n        # [B, N, K*K, C] -> [B, N, K*K*C] -> [N, B, K*K*C]\n        obj_feat = obj_feat.reshape(N, num_proposals,\n                                    -1).permute(1, 0, 2)\n        obj_feat = self.attention_norm(self.attention(obj_feat))\n        # [N, B, K*K*C] -> [B, N, K*K*C]\n        obj_feat = obj_feat.permute(1, 0, 2)\n\n        # obj_feat in shape [B, N, K*K*C] -> [B, N, K*K, C]\n        obj_feat = obj_feat.reshape(N, num_proposals, -1, self.in_channels)\n\n        # FFN\n        if self.with_ffn:\n            obj_feat = self.ffn_norm(self.ffn(obj_feat))\n\n        # Tracking part\n        # Link previous and current kernels if previous obj feats are given\n        if previous_obj_feats is not None:\n            # previous_obj_feats (b, n, c, k, k) -> (b, n, c, k*k) -> (b, n, k*k, c)\n            # permute to correct dimension\n\n            if self.previous_type == \"ffn\":\n                previous_obj_feats = previous_obj_feats.reshape(N, num_proposals,\n                                                                self.in_channels,\n                                                                -1).permute(0, 1, 3, 2)\n                cur_obj_feat = obj_feat.reshape(N, num_proposals, self.in_channels * 
self.conv_kernel_size ** 2). \\\n                    permute(1, 0, 2)\n                previous_obj_feats = previous_obj_feats.reshape(N, num_proposals,\n                                                                self.in_channels * self.conv_kernel_size ** 2).permute(\n                    1, 0, 2)\n\n                previous_obj_feat = self.attention_previous_norm(\n                    self.attention_previous(\n                        query=cur_obj_feat,\n                        key=previous_obj_feats,\n                        value=previous_obj_feats,\n                        identity=cur_obj_feat\n                    ),\n                )\n                previous_obj_feat = previous_obj_feat.permute(1, 0, 2)\n                previous_obj_feat_track = previous_obj_feat.reshape(N, num_proposals, -1, self.in_channels)\n                # pre_obj_feat in shape [B, N, K*K*C] -> [B, N, K*K, C]\n                previous_obj_feat_track = self.link_ffn_norm(self.link_ffn(previous_obj_feat_track))\n\n            elif self.previous_type == \"update\":\n                # not work\n                previous_obj_feats = previous_obj_feats.reshape(N, num_proposals,\n                                                                self.in_channels,\n                                                                -1).permute(0, 1, 3, 2)\n                previous_obj_feats_track = self.attention_previous_update_track(x_feat, previous_obj_feats)\n\n                previous_obj_feats_track = previous_obj_feats_track.reshape(N, num_proposals,\n                                                                            self.in_channels,\n                                                                            -1).permute(0, 1, 3, 2)\n                cur_obj_feat = obj_feat.reshape(N, num_proposals, self.in_channels * self.conv_kernel_size ** 2). 
\\\n                    permute(1, 0, 2)\n                previous_obj_feats_track = previous_obj_feats_track.reshape(N, num_proposals,\n                                                                            self.in_channels * self.conv_kernel_size ** 2).permute(\n                    1, 0, 2)\n\n                previous_obj_feats_track = self.attention_previous_norm_track(\n                    self.attention_previous_track(\n                        query=cur_obj_feat,\n                        key=previous_obj_feats_track,\n                        value=previous_obj_feats_track,\n                        identity=cur_obj_feat\n                    ),\n                )\n                previous_obj_feats_track = previous_obj_feats_track.permute(1, 0, 2)\n                previous_obj_feats_track = previous_obj_feats_track.reshape(N, num_proposals, -1, self.in_channels)\n                # pre_obj_feat in shape [B, N, K*K*C] -> [B, N, K*K, C]\n                previous_obj_feat_track = self.link_ffn_norm_track(self.link_ffn_track(previous_obj_feats_track))\n\n            elif self.previous_type == \"update_obj\":\n                # not work\n                previous_obj_feats = previous_obj_feats.reshape(N, num_proposals,\n                                                                self.in_channels,\n                                                                -1).permute(0, 1, 3, 2)\n                previous_obj_feats_track = self.attention_previous_update_track(obj_feat.squeeze(2), previous_obj_feats)\n\n                previous_obj_feats_track = previous_obj_feats_track.reshape(N, num_proposals,\n                                                                            self.in_channels,\n                                                                            -1).permute(0, 1, 3, 2)\n                cur_obj_feat = obj_feat.reshape(N, num_proposals, self.in_channels * self.conv_kernel_size ** 2). 
\\\n                    permute(1, 0, 2)\n                previous_obj_feats_track = previous_obj_feats_track.reshape(N, num_proposals,\n                                                                            self.in_channels * self.conv_kernel_size ** 2).permute(\n                    1, 0, 2)\n\n                previous_obj_feats_track = self.attention_previous_norm_track(\n                    self.attention_previous_track(\n                        query=cur_obj_feat,\n                        key=previous_obj_feats_track,\n                        value=previous_obj_feats_track,\n                        identity=cur_obj_feat\n                    ),\n                )\n                previous_obj_feats_track = previous_obj_feats_track.permute(1, 0, 2)\n                previous_obj_feats_track = previous_obj_feats_track.reshape(N, num_proposals, -1, self.in_channels)\n                # pre_obj_feat in shape [B, N, K*K*C] -> [B, N, K*K, C]\n                previous_obj_feat_track = self.link_ffn_norm_track(self.link_ffn_track(previous_obj_feats_track))\n            else:\n                previous_obj_feat_track = None\n\n        cls_feat = obj_feat.sum(-2)\n        mask_feat = obj_feat\n\n        for cls_layer in self.cls_fcs:\n            cls_feat = cls_layer(cls_feat)\n        for reg_layer in self.mask_fcs:\n            mask_feat = reg_layer(mask_feat)\n\n        cls_score = self.fc_cls(cls_feat).view(N, num_proposals, -1)\n        # [B, N, K*K, C] -> [B, N, C, K*K]\n        mask_feat = self.fc_mask(mask_feat).permute(0, 1, 3, 2)\n\n        if (self.mask_transform_stride == 2\n                and self.feat_gather_stride == 1):\n            mask_x = F.interpolate(\n                x, scale_factor=0.5, mode='bilinear', align_corners=False)\n            H, W = mask_x.shape[-2:]\n        else:\n            mask_x = x\n        # group conv is 5x faster than unfold and uses about 1/5 memory\n        # Group conv vs. unfold vs. 
concat batch, 2.9ms :13.5ms :3.8ms\n        # Group conv vs. unfold vs. concat batch, 278 : 1420 : 369\n        # fold_x = F.unfold(\n        #     mask_x,\n        #     self.conv_kernel_size,\n        #     padding=int(self.conv_kernel_size // 2))\n        # mask_feat = mask_feat.reshape(N, num_proposals, -1)\n        # new_mask_preds = torch.einsum('bnc,bcl->bnl', mask_feat, fold_x)\n        # [B, N, C, K*K] -> [B*N, C, K, K]\n        mask_feat = mask_feat.reshape(N, num_proposals, C,\n                                      self.conv_kernel_size,\n                                      self.conv_kernel_size)\n        # [B, C, H, W] -> [1, B*C, H, W]\n        new_mask_preds = []\n        for i in range(N):\n            new_mask_preds.append(\n                F.conv2d(\n                    mask_x[i:i + 1],\n                    mask_feat[i],\n                    padding=int(self.conv_kernel_size // 2)))\n\n        new_mask_preds = torch.cat(new_mask_preds, dim=0)\n        new_mask_preds = new_mask_preds.reshape(N, num_proposals, H, W)\n        if self.mask_transform_stride == 2:\n            new_mask_preds = F.interpolate(\n                new_mask_preds,\n                scale_factor=2,\n                mode='bilinear',\n                align_corners=False)\n\n        if mask_shape is not None and mask_shape[0] != H:\n            new_mask_preds = F.interpolate(\n                new_mask_preds,\n                mask_shape,\n                align_corners=False,\n                mode='bilinear')\n\n        if previous_obj_feats is not None and previous_obj_feat_track is not None:\n            return cls_score, new_mask_preds, obj_feat.permute(0, 1, 3, 2).reshape(\n                N, num_proposals, self.in_channels, self.conv_kernel_size, self.conv_kernel_size), x_feat, \\\n                   previous_obj_feat_track.permute(0, 1, 3, 2).reshape(\n                       N, num_proposals, self.in_channels, self.conv_kernel_size, self.conv_kernel_size)\n        else:\n      
      return cls_score, new_mask_preds, obj_feat.permute(0, 1, 3, 2).reshape(\n                N, num_proposals, self.in_channels, self.conv_kernel_size, self.conv_kernel_size), x_feat, None\n\n    @force_fp32(apply_to=('cls_score', 'mask_pred'))\n    def loss(self,\n             object_feats,\n             cls_score,\n             mask_pred,\n             labels,\n             label_weights,\n             mask_targets,\n             mask_weights,\n             imgs_whwh=None,\n             reduction_override=None,\n             **kwargs):\n\n        losses = dict()\n        bg_class_ind = self.num_classes\n        # note: in Sparse R-CNN, num_gt == num_pos\n        pos_inds = (labels >= 0) & (labels < bg_class_ind)\n        num_pos = pos_inds.sum().float()\n        avg_factor = reduce_mean(num_pos).clamp_(min=1.0)\n\n        num_preds = mask_pred.shape[0] * mask_pred.shape[1]\n        assert mask_pred.shape[0] == cls_score.shape[0]\n        assert mask_pred.shape[1] == cls_score.shape[1]\n\n        if cls_score is not None:\n            if cls_score.numel() > 0:\n                losses['loss_cls'] = self.loss_cls(\n                    cls_score.view(num_preds, -1),\n                    labels,\n                    label_weights,\n                    avg_factor=avg_factor,\n                    reduction_override=reduction_override)\n                losses['pos_acc'] = accuracy(\n                    cls_score.view(num_preds, -1)[pos_inds], labels[pos_inds])\n        if mask_pred is not None:\n            bool_pos_inds = pos_inds.type(torch.bool)\n            # 0~self.num_classes-1 are FG, self.num_classes is BG\n            # mask and dice losses are computed only on FG (positive) predictions.\n            H, W = mask_pred.shape[-2:]\n            if pos_inds.any():\n                pos_mask_pred = mask_pred.reshape(num_preds, H,\n                                                  W)[bool_pos_inds]\n                pos_mask_targets = mask_targets[bool_pos_inds]\n                
losses['loss_mask'] = self.loss_mask(pos_mask_pred,\n                                                     pos_mask_targets)\n                losses['loss_dice'] = self.loss_dice(pos_mask_pred,\n                                                     pos_mask_targets)\n\n                if self.loss_rank is not None:\n                    batch_size = mask_pred.size(0)\n                    rank_target = mask_targets.new_full((batch_size, H, W),\n                                                        self.ignore_label,\n                                                        dtype=torch.long)\n                    rank_inds = pos_inds.view(batch_size,\n                                              -1).nonzero(as_tuple=False)\n                    batch_mask_targets = mask_targets.view(\n                        batch_size, -1, H, W).bool()\n                    for i in range(batch_size):\n                        curr_inds = (rank_inds[:, 0] == i)\n                        curr_rank = rank_inds[:, 1][curr_inds]\n                        for j in curr_rank:\n                            rank_target[i][batch_mask_targets[i][j]] = j\n                    losses['loss_rank'] = self.loss_rank(\n                        mask_pred, rank_target, ignore_index=self.ignore_label)\n            else:\n                losses['loss_mask'] = mask_pred.sum() * 0\n                losses['loss_dice'] = mask_pred.sum() * 0\n                if self.loss_rank is not None:\n                    losses['loss_rank'] = mask_pred.sum() * 0\n\n        return losses\n\n    def _get_target_single(self, pos_inds, neg_inds, pos_mask, neg_mask,\n                           pos_gt_mask, pos_gt_labels, gt_sem_seg, gt_sem_cls,\n                           cfg):\n\n        num_pos = pos_mask.size(0)\n        num_neg = neg_mask.size(0)\n        num_samples = num_pos + num_neg\n        H, W = pos_mask.shape[-2:]\n        # original implementation uses new_zeros since BG are set to be 0\n        # now use empty & fill 
because BG cat_id = num_classes,\n        # FG cat_id = [0, num_classes-1]\n        labels = pos_mask.new_full((num_samples,),\n                                   self.num_classes,\n                                   dtype=torch.long)\n        label_weights = pos_mask.new_zeros((num_samples, self.num_classes))\n        mask_targets = pos_mask.new_zeros(num_samples, H, W)\n        mask_weights = pos_mask.new_zeros(num_samples, H, W)\n        if num_pos > 0:\n            labels[pos_inds] = pos_gt_labels\n            pos_weight = 1.0 if cfg.pos_weight <= 0 else cfg.pos_weight\n            label_weights[pos_inds] = pos_weight\n            pos_mask_targets = pos_gt_mask\n            mask_targets[pos_inds, ...] = pos_mask_targets\n            mask_weights[pos_inds, ...] = 1\n\n        if num_neg > 0:\n            label_weights[neg_inds] = 1.0\n\n        if gt_sem_cls is not None and gt_sem_seg is not None:\n            sem_labels = pos_mask.new_full((self.num_stuff_classes,),\n                                           self.num_classes,\n                                           dtype=torch.long)\n            sem_targets = pos_mask.new_zeros(self.num_stuff_classes, H, W)\n            sem_weights = pos_mask.new_zeros(self.num_stuff_classes, H, W)\n            sem_stuff_weights = torch.eye(\n                self.num_stuff_classes, device=pos_mask.device)\n            sem_thing_weights = pos_mask.new_zeros(\n                (self.num_stuff_classes, self.num_thing_classes))\n            sem_label_weights = torch.cat(\n                [sem_thing_weights, sem_stuff_weights], dim=-1)\n            if len(gt_sem_cls) > 0:\n                sem_inds = gt_sem_cls - self.num_thing_classes\n                sem_inds = sem_inds.long()\n                sem_labels[sem_inds] = gt_sem_cls.long()\n                sem_targets[sem_inds] = gt_sem_seg\n                sem_weights[sem_inds] = 1\n\n            label_weights[:, self.num_thing_classes:] = 0\n            labels = torch.cat([labels, 
sem_labels])\n            label_weights = torch.cat([label_weights, sem_label_weights])\n            mask_targets = torch.cat([mask_targets, sem_targets])\n            mask_weights = torch.cat([mask_weights, sem_weights])\n\n        return labels, label_weights, mask_targets, mask_weights\n\n    def get_targets(self,\n                    sampling_results,\n                    gt_mask,\n                    gt_labels,\n                    rcnn_train_cfg,\n                    concat=True,\n                    gt_sem_seg=None,\n                    gt_sem_cls=None\n                    ):\n\n        pos_inds_list = [res.pos_inds for res in sampling_results]\n        neg_inds_list = [res.neg_inds for res in sampling_results]\n        pos_mask_list = [res.pos_masks for res in sampling_results]\n        neg_mask_list = [res.neg_masks for res in sampling_results]\n        pos_gt_mask_list = [res.pos_gt_masks for res in sampling_results]\n        pos_gt_labels_list = [res.pos_gt_labels for res in sampling_results]\n        if gt_sem_seg is None:\n            # one entry per image: multi_apply zips its inputs, so a shorter\n            # list would silently drop images from the batch\n            gt_sem_seg = [None] * len(sampling_results)\n            gt_sem_cls = [None] * len(sampling_results)\n\n        labels, label_weights, mask_targets, mask_weights = multi_apply(\n            self._get_target_single,\n            pos_inds_list,\n            neg_inds_list,\n            pos_mask_list,\n            neg_mask_list,\n            pos_gt_mask_list,\n            pos_gt_labels_list,\n            gt_sem_seg,\n            gt_sem_cls,\n            cfg=rcnn_train_cfg)\n        if concat:\n            labels = torch.cat(labels, 0)\n            label_weights = torch.cat(label_weights, 0)\n            mask_targets = torch.cat(mask_targets, 0)\n            mask_weights = torch.cat(mask_weights, 0)\n        return labels, label_weights, mask_targets, mask_weights\n\n    def rescale_masks(self, masks_per_img, img_meta):\n        h, w, _ = img_meta['img_shape']\n        masks_per_img = F.interpolate(\n            masks_per_img.unsqueeze(0).sigmoid(),\n            
size=img_meta['batch_input_shape'],\n            mode='bilinear',\n            align_corners=False)\n\n        masks_per_img = masks_per_img[:, :, :h, :w]\n        ori_shape = img_meta['ori_shape']\n        seg_masks = F.interpolate(\n            masks_per_img,\n            size=ori_shape[:2],\n            mode='bilinear',\n            align_corners=False).squeeze(0)\n        return seg_masks\n\n    def get_seg_masks(self, masks_per_img, labels_per_img, scores_per_img,\n                      test_cfg, img_meta):\n        # resize mask predictions back\n        seg_masks = self.rescale_masks(masks_per_img, img_meta)\n        seg_masks = seg_masks > test_cfg.mask_thr\n        bbox_result, segm_result, mask_preds = self.segm2result(seg_masks, labels_per_img,\n                                                                scores_per_img)\n        return bbox_result, segm_result, mask_preds\n\n    def segm2result(self, mask_preds, det_labels, cls_scores):\n        num_classes = self.num_classes\n        bbox_result = None\n        segm_result = [[] for _ in range(num_classes)]\n        det_labels = det_labels.cpu().numpy()\n        cls_scores = cls_scores.cpu().numpy()\n        num_ins = mask_preds.shape[0]\n        # build pseudo boxes [x1, y1, x2, y2, score] from the binary masks\n        bboxes = np.zeros((num_ins, 5), dtype=np.float32)\n        bboxes[:, -1] = cls_scores\n        bboxes[:, :4] = np.array(tensor_mask2box(mask_preds).clip(min=0))\n        # mask_preds = mask_preds.cpu().numpy()\n        # bbox_result = [bboxes[det_labels == i, :] for i in range(num_classes)]\n        for idx in range(num_ins):\n            segm_result[det_labels[idx]].append(mask_preds[idx])\n        return bboxes, segm_result, mask_preds"
  },
  {
    "path": "knet/video/knet.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom mmdet.models.builder import DETECTORS\nfrom mmdet.models.detectors import TwoStageDetector, BaseDetector\nfrom mmdet.models.builder import build_head\nfrom knet.det.utils import sem2ins_masks, sem2ins_masks_cityscapes\n\n\n@DETECTORS.register_module()\nclass VideoKNet(TwoStageDetector):\n\n    def __init__(self,\n                 *args,\n                 num_thing_classes=80,\n                 num_stuff_classes=53,\n                 mask_assign_stride=4,\n                 ignore_label=255,\n                 thing_label_in_seg=0,\n                 cityscapes=False,\n                 **kwargs):\n        super(VideoKNet, self).__init__(*args, **kwargs)\n        assert self.with_rpn, 'KNet does not support external proposals'\n\n        self.num_thing_classes = num_thing_classes\n        self.num_stuff_classes = num_stuff_classes\n        self.mask_assign_stride = mask_assign_stride\n        self.thing_label_in_seg = thing_label_in_seg\n        self.ignore_label = ignore_label\n        self.cityscapes = cityscapes  # whether to train the cityscape panoptic segmentation\n\n    def forward_train(self,\n                      img,\n                      img_metas,\n                      gt_bboxes=None,\n                      gt_labels=None,\n                      gt_bboxes_ignore=None,\n                      gt_masks=None,\n                      gt_semantic_seg=None,\n                      ref_img=None,\n                      ref_img_metas=None,\n                      ref_gt_bboxes_ignore=None,\n                      ref_gt_labels=None,\n                      ref_gt_masks=None,\n                      ref_gt_semantic_seg=None,\n                      proposals=None,\n                      **kwargs):\n        \"\"\"Forward function of SparseR-CNN-like network in train stage.\n\n        Args:\n            img (Tensor): of shape (N, C, H, W) encoding input images.\n                Typically these should be mean 
centered and std scaled.\n            img_metas (list[dict]): list of image info dict where each dict\n                has: 'img_shape', 'scale_factor', 'flip', and may also contain\n                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.\n                For details on the values of these keys see\n                :class:`mmdet.datasets.pipelines.Collect`.\n            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with\n                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.\n            gt_labels (list[Tensor]): class indices corresponding to each box\n            gt_bboxes_ignore (None | list[Tensor]): specify which bounding\n                boxes can be ignored when computing the loss.\n            gt_masks (list): Ground truth instance masks for each image.\n                Required by this architecture.\n            proposals (List[Tensor], optional): override rpn proposals with\n                custom proposals. Use when `with_rpn` is False.\n            # This is for video only:\n            ref_img (Tensor): of shape (N, 2, C, H, W) encoding input images.\n                Typically these should be mean centered and std scaled.\n                2 denotes that there are two reference images for each input image.\n\n            ref_img_metas (list[list[dict]]): The first list only has one\n                element. The second list contains reference image information\n                dict where each dict has: 'img_shape', 'scale_factor', 'flip',\n                and may also contain 'filename', 'ori_shape', 'pad_shape', and\n                'img_norm_cfg'. For details on the values of these keys see\n                `mmtrack/datasets/pipelines/formatting.py:VideoCollect`.\n\n            ref_gt_bboxes (list[Tensor]): The list only has one Tensor. 
The\n                Tensor contains ground truth bboxes for each reference image\n                with shape (num_all_ref_gts, 5) in\n                [ref_img_id, tl_x, tl_y, br_x, br_y] format. The ref_img_id\n                starts from 0 and denotes the id of the reference image for each\n                key image.\n\n            ref_gt_labels (list[Tensor]): The list only has one Tensor. The\n                Tensor contains class indices corresponding to each reference\n                box with shape (num_all_ref_gts, 2) in\n                [ref_img_id, class_index].\n\n        Returns:\n            dict[str, Tensor]: a dictionary of loss components\n        \"\"\"\n        super(TwoStageDetector, self).forward_train(img, img_metas)\n        assert proposals is None, 'KNet does not support' \\\n                                  ' external proposals'\n        assert gt_masks is not None\n        # gt_masks and gt_semantic_seg are not padded when forming batch\n        gt_masks_tensor = []\n        gt_sem_seg = []\n        gt_sem_cls = []\n        # batch_input_shape should be the same across images\n        pad_H, pad_W = img_metas[0]['batch_input_shape']\n        assign_H = pad_H // self.mask_assign_stride\n        assign_W = pad_W // self.mask_assign_stride\n\n        for i, gt_mask in enumerate(gt_masks):\n            mask_tensor = gt_mask.to_tensor(torch.float, gt_labels[0].device)\n            if gt_mask.width != pad_W or gt_mask.height != pad_H:\n                pad_wh = (0, pad_W - gt_mask.width, 0, pad_H - gt_mask.height)\n                mask_tensor = F.pad(mask_tensor, pad_wh, value=0)\n\n            if gt_semantic_seg is not None:\n                # gt_semantic_seg is padded with zeros when forming a batch,\n                # so reset the padded region to ignore_label\n                gt_semantic_seg[\n                    i, :, img_metas[i]['img_shape'][0]:, :] = self.ignore_label\n                gt_semantic_seg[\n                    i, :, :, 
img_metas[i]['img_shape'][1]:] = self.ignore_label\n                if self.cityscapes:\n                    sem_labels, sem_seg = sem2ins_masks_cityscapes(\n                        gt_semantic_seg[i],\n                        ignore_label=self.ignore_label,\n                        label_shift=self.num_thing_classes)\n                else:\n                    sem_labels, sem_seg = sem2ins_masks(\n                        gt_semantic_seg[i],\n                        ignore_label=self.ignore_label,\n                        label_shift=self.num_thing_classes,\n                        thing_label_in_seg=self.thing_label_in_seg)\n\n                if sem_seg.shape[0] == 0:\n                    gt_sem_seg.append(\n                        mask_tensor.new_zeros(\n                            (mask_tensor.size(0), assign_H, assign_W)))\n                else:\n                    gt_sem_seg.append(\n                        F.interpolate(\n                            sem_seg[None], (assign_H, assign_W),\n                            mode='bilinear',\n                            align_corners=False)[0])\n                gt_sem_cls.append(sem_labels)\n            else:\n                gt_sem_seg = None\n                gt_sem_cls = None\n\n            if mask_tensor.shape[0] == 0:\n                gt_masks_tensor.append(\n                    mask_tensor.new_zeros(\n                        (mask_tensor.size(0), assign_H, assign_W)))\n            else:\n                gt_masks_tensor.append(\n                    F.interpolate(\n                        mask_tensor[None], (assign_H, assign_W),  # downsample to 1/4 resolution\n                        mode='bilinear',\n                        align_corners=False)[0])\n\n        gt_masks = gt_masks_tensor\n\n        x = self.extract_feat(img)\n        rpn_results = self.rpn_head.forward_train(x, img_metas, gt_masks,\n                                                  gt_labels, gt_sem_seg,\n                                             
     gt_sem_cls)\n        (rpn_losses, proposal_feats, x_feats, mask_preds,\n         cls_scores) = rpn_results\n\n        losses = self.roi_head.forward_train(\n            x_feats,\n            proposal_feats,\n            mask_preds,\n            cls_scores,\n            img_metas,\n            gt_masks,\n            gt_labels,\n            gt_bboxes_ignore=gt_bboxes_ignore,\n            gt_bboxes=gt_bboxes,\n            gt_sem_seg=gt_sem_seg,\n            gt_sem_cls=gt_sem_cls,\n            imgs_whwh=None)\n\n        losses.update(rpn_losses)\n        return losses\n\n    def simple_test(self, img, img_metas, rescale=False):\n        \"\"\"Test function without test time augmentation.\n\n        Args:\n            imgs (list[torch.Tensor]): List of multiple images\n            img_metas (list[dict]): List of image information.\n            rescale (bool): Whether to rescale the results.\n                Defaults to False.\n\n        Returns:\n            list[list[np.ndarray]]: BBox results of each image and classes.\n                The outer list corresponds to each image. 
The inner list\n                corresponds to each class.\n        \"\"\"\n        x = self.extract_feat(img)\n        rpn_results = self.rpn_head.simple_test_rpn(x, img_metas)\n        (proposal_feats, x_feats, mask_preds, cls_scores,\n         seg_preds) = rpn_results\n        segm_results = self.roi_head.simple_test(\n            x_feats,\n            proposal_feats,\n            mask_preds,\n            cls_scores,\n            img_metas,\n            imgs_whwh=None,\n            rescale=rescale)\n        return segm_results\n\n    def forward_dummy(self, img):\n        \"\"\"Used for computing network flops.\n\n        See `mmdetection/tools/get_flops.py`\n        \"\"\"\n        # backbone\n        x = self.extract_feat(img)\n        # rpn\n        num_imgs = len(img)\n        dummy_img_metas = [\n            dict(img_shape=(800, 1333, 3)) for _ in range(num_imgs)\n        ]\n        rpn_results = self.rpn_head.simple_test_rpn(x, dummy_img_metas)\n        (proposal_feats, x_feats, mask_preds, cls_scores,\n         seg_preds) = rpn_results\n        # roi_head\n        roi_outs = self.roi_head.forward_dummy(x_feats, proposal_feats,\n                                               dummy_img_metas)\n        return roi_outs\n\n    def extract_feat(self, img):\n        \"\"\"Directly extract features from the backbone+neck.\"\"\"\n        x = self.backbone(img)\n        if self.with_neck:\n            x = self.neck(x)\n        return x"
  },
  {
    "path": "knet/video/knet_quansi_dense.py",
    "content": "import warnings\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom mmcv.cnn import ConvModule\nfrom mmdet.models.builder import DETECTORS\nfrom mmdet.models.detectors import BaseDetector\nfrom mmdet.models.builder import build_head, build_neck, build_backbone\nfrom mmdet.core import build_assigner, build_sampler\nfrom knet.video.qdtrack.builder import build_tracker\nfrom knet.det.utils import sem2ins_masks, sem2ins_masks_cityscapes\nfrom unitrack.mask import tensor_mask2box\n\n@DETECTORS.register_module()\nclass VideoKNetQuansiTrack(BaseDetector):\n    \"\"\"\n        Simple Extension of KNet to Video KNet by the implementation of VPSFuse Net.\n    \"\"\"\n    def __init__(self,\n                 backbone,\n                 neck=None,\n                 rpn_head=None,\n                 roi_head=None,\n                 track_head=None,\n                 extra_neck=None,\n                 track_localization_fpn=None,\n                 tracker=None,\n                 train_cfg=None,\n                 test_cfg=None,\n                 track_train_cfg=None,\n                 pretrained=None,\n                 init_cfg=None,\n                 num_thing_classes=80,\n                 num_stuff_classes=53,\n                 mask_assign_stride=4,\n                 ignore_label=255,\n                 thing_label_in_seg=0,\n                 cityscapes=False,\n                 kitti_step=False,\n                 fix_knet=False,\n                 freeze_detector=False,\n                 semantic_filter=False,\n                 **kwargs):\n        super(VideoKNetQuansiTrack, self).__init__(init_cfg)\n\n        if pretrained:\n            warnings.warn('DeprecationWarning: pretrained is deprecated, '\n                          'please use \"init_cfg\" instead')\n            backbone.pretrained = pretrained\n        self.backbone = build_backbone(backbone)\n\n        if neck is not None:\n            self.neck = build_neck(neck)\n\n        if 
extra_neck is not None:\n            self.extra_neck = build_neck(extra_neck)\n\n        if rpn_head is not None:\n            rpn_train_cfg = train_cfg.rpn if train_cfg is not None else None\n            rpn_head_ = rpn_head.copy()\n            rpn_head_.update(train_cfg=rpn_train_cfg, test_cfg=test_cfg.rpn)\n            self.rpn_head = build_head(rpn_head_)\n\n        if roi_head is not None:\n            # update train and test cfg here for now\n            # TODO: refactor assigner & sampler\n            rcnn_train_cfg = train_cfg.rcnn if train_cfg is not None else None\n            roi_head.update(train_cfg=rcnn_train_cfg)\n            roi_head.update(test_cfg=test_cfg.rcnn)\n            roi_head.pretrained = pretrained\n            self.roi_head = build_head(roi_head)\n\n        if track_head is not None:\n            self.track_train_cfg = track_train_cfg\n            self.track_head = build_head(track_head)\n            self.init_track_assigner_sampler()\n\n        if track_localization_fpn is not None:\n            self.track_localization_fpn = build_neck(track_localization_fpn)\n\n        if tracker is not None:\n            self.tracker_cfg = tracker\n\n        if freeze_detector:\n           self._freeze_detector()\n\n        if fix_knet:\n            for p in self.backbone.parameters():\n                p.requires_grad_(False)\n            self.backbone.eval()\n            for p in self.neck.parameters():\n                p.requires_grad_(False)\n            self.neck.eval()\n            for p in self.rpn_head.parameters():\n                p.requires_grad_(False)\n            self.rpn_head.eval()\n            for p in self.roi_head.parameters():\n                p.requires_grad_(False)\n            self.roi_head.eval()\n\n        self.train_cfg = train_cfg\n        self.test_cfg = test_cfg\n        self.num_proposals = self.rpn_head.num_proposals\n        self.num_thing_classes = num_thing_classes\n        self.num_stuff_classes = num_stuff_classes\n  
      self.mask_assign_stride = mask_assign_stride\n        self.thing_label_in_seg = thing_label_in_seg\n        self.ignore_label = ignore_label\n        self.cityscapes = cityscapes  # whether to train Cityscapes panoptic segmentation\n        self.kitti_step = kitti_step\n        self.semantic_filter = semantic_filter\n\n    def init_tracker(self):\n        self.tracker = build_tracker(self.tracker_cfg)\n\n    def _freeze_detector(self):\n\n        self.detector = [\n            self.rpn_head, self.roi_head\n        ]\n        for model in self.detector:\n            model.eval()\n            for param in model.parameters():\n                param.requires_grad = False\n\n    def init_track_assigner_sampler(self):\n        \"\"\"Initialize assigner and sampler.\"\"\"\n\n        self.track_roi_assigner = build_assigner(\n            self.track_train_cfg.assigner)\n        self.track_share_assigner = False\n\n        self.track_roi_sampler = build_sampler(\n            self.track_train_cfg.sampler, context=self)\n        self.track_share_sampler = False\n\n    def preprocess_gt_masks(self, img_metas, gt_masks, gt_labels, gt_semantic_seg):\n        # gt_masks and gt_semantic_seg are not padded when forming batch\n        gt_masks_tensor = []\n        gt_sem_seg = []\n        gt_sem_cls = []\n        # batch_input_shape should be the same across images\n        pad_H, pad_W = img_metas[0]['batch_input_shape']\n        assign_H = pad_H // self.mask_assign_stride\n        assign_W = pad_W // self.mask_assign_stride\n\n        for i, gt_mask in enumerate(gt_masks):\n            mask_tensor = gt_mask.to_tensor(torch.float, gt_labels[0].device)\n            if gt_mask.width != pad_W or gt_mask.height != pad_H:\n                pad_wh = (0, pad_W - gt_mask.width, 0, pad_H - gt_mask.height)\n                mask_tensor = F.pad(mask_tensor, pad_wh, value=0)\n\n            if gt_semantic_seg is not None:\n                # gt_semantic_seg is padded with zeros when forming a 
batch\n                # need to convert them from 0 to ignore\n                gt_semantic_seg[\n                i, :, img_metas[i]['img_shape'][0]:, :] = self.ignore_label\n                gt_semantic_seg[\n                i, :, :, img_metas[i]['img_shape'][1]:] = self.ignore_label\n                if self.cityscapes:\n                    sem_labels, sem_seg = sem2ins_masks_cityscapes(\n                        gt_semantic_seg[i],\n                        ignore_label=self.ignore_label,\n                        label_shift=self.num_thing_classes)\n                else:\n                    sem_labels, sem_seg = sem2ins_masks(\n                        gt_semantic_seg[i],\n                        ignore_label=self.ignore_label,\n                        label_shift=self.num_thing_classes,\n                        thing_label_in_seg=self.thing_label_in_seg)\n\n                if sem_seg.shape[0] == 0:\n                    gt_sem_seg.append(\n                        mask_tensor.new_zeros(\n                            (mask_tensor.size(0), assign_H, assign_W)))\n                else:\n                    gt_sem_seg.append(\n                        F.interpolate(\n                            sem_seg[None], (assign_H, assign_W),\n                            mode='bilinear',\n                            align_corners=False)[0])\n                gt_sem_cls.append(sem_labels)\n            else:\n                gt_sem_seg = None\n                gt_sem_cls = None\n\n            if mask_tensor.shape[0] == 0:\n                gt_masks_tensor.append(\n                    mask_tensor.new_zeros(\n                        (mask_tensor.size(0), assign_H, assign_W)))\n            else:\n                gt_masks_tensor.append(\n                    F.interpolate(\n                        mask_tensor[None], (assign_H, assign_W),  # downsample to 1/4 resolution\n                        mode='bilinear',\n                        align_corners=False)[0])\n\n        return gt_masks_tensor, 
gt_sem_cls, gt_sem_seg\n\n    def forward_train(self,\n                      img,\n                      img_metas,\n                      gt_bboxes=None,\n                      gt_labels=None,\n                      gt_bboxes_ignore=None,\n                      gt_masks=None,\n                      gt_semantic_seg=None,\n                      gt_instance_ids=None,\n                      ref_img=None,\n                      ref_img_metas=None,\n                      ref_gt_bboxes_ignore=None,\n                      ref_gt_labels=None,\n                      ref_gt_bboxes=None,\n                      ref_gt_masks=None,\n                      ref_gt_semantic_seg=None,\n                      ref_gt_instance_ids=None,\n                      proposals=None,\n                      **kwargs):\n        \"\"\"Forward function of SparseR-CNN-like network in train stage.\n\n        Args:\n            img (Tensor): of shape (N, C, H, W) encoding input images.\n                Typically these should be mean centered and std scaled.\n            img_metas (list[dict]): list of image info dict where each dict\n                has: 'img_shape', 'scale_factor', 'flip', and may also contain\n                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.\n                For details on the values of these keys see\n                :class:`mmdet.datasets.pipelines.Collect`.\n            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with\n                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.\n            gt_labels (list[Tensor]): class indices corresponding to each box\n            gt_bboxes_ignore (None | list[Tensor]): specify which bounding\n                boxes can be ignored when computing the loss.\n            gt_masks (list): Ground truth instance masks for\n                each image. 
Required by this architecture.\n            proposals (List[Tensor], optional): override rpn proposals with\n                custom proposals. Use when `with_rpn` is False.\n\n            # This is for video only:\n            ref_img (Tensor): of shape (N, 2, C, H, W) encoding input images.\n                Typically these should be mean centered and std scaled.\n                2 denotes that there are two reference images for each input image.\n\n            ref_img_metas (list[list[dict]]): The first list only has one\n                element. The second list contains reference image information\n                dict where each dict has: 'img_shape', 'scale_factor', 'flip',\n                and may also contain 'filename', 'ori_shape', 'pad_shape', and\n                'img_norm_cfg'. For details on the values of these keys see\n                `mmtrack/datasets/pipelines/formatting.py:VideoCollect`.\n\n            ref_gt_bboxes (list[Tensor]): The list only has one Tensor. The\n                Tensor contains ground truth bboxes for each reference image\n                with shape (num_all_ref_gts, 5) in\n                [ref_img_id, tl_x, tl_y, br_x, br_y] format. The ref_img_id\n                starts from 0 and denotes the id of the reference image for each\n                key image.\n\n            ref_gt_labels (list[Tensor]): The list only has one Tensor. 
The\n                Tensor contains class indices corresponding to each reference\n                box with shape (num_all_ref_gts, 2) in\n                [ref_img_id, class_indice].\n\n        Returns:\n            dict[str, Tensor]: a dictionary of loss components\n        \"\"\"\n        batch_input_shape = tuple(img[0].size()[-2:])\n        for img_meta in img_metas:\n            img_meta['batch_input_shape'] = batch_input_shape\n\n        assert proposals is None, 'KNet does not support' \\\n                                  ' external proposals'\n        assert gt_masks is not None\n        assert gt_instance_ids is not None\n\n        # preprocess the reference images\n        ref_img = ref_img.squeeze(1)  # (b,3,h,w)\n        ref_masks_gt = []\n        for ref_gt_mask in ref_gt_masks:\n            ref_masks_gt.append(ref_gt_mask[0])\n\n        ref_labels_gt = []\n        for ref_gt_label in ref_gt_labels:\n            ref_labels_gt.append(ref_gt_label[:, 1].long())\n        ref_gt_labels = ref_labels_gt\n\n        ref_semantic_seg_gt = ref_gt_semantic_seg.squeeze(1)\n\n        ref_gt_instance_id_list = []\n        for ref_gt_instance_id in ref_gt_instance_ids:\n            ref_gt_instance_id_list.append(ref_gt_instance_id[:,1].long())\n\n        ref_img_metas_new = []\n        for ref_img_meta in ref_img_metas:\n            ref_img_meta[0]['batch_input_shape'] = batch_input_shape\n            ref_img_metas_new.append(ref_img_meta[0])\n\n        # prepare the gt_match_indices\n        gt_pids_list =[]\n        for i in range(len(ref_gt_instance_id_list)):\n            ref_ids = ref_gt_instance_id_list[i].cpu().data.numpy().tolist()\n            gt_ids = gt_instance_ids[i].cpu().data.numpy().tolist()\n            gt_pids = [ref_ids.index(i) if i in ref_ids else -1 for i in gt_ids]\n            gt_pids_list.append(torch.LongTensor([gt_pids]).to(img.device)[0])\n\n        gt_match_indices = gt_pids_list\n\n        # gt_masks and gt_semantic_seg are not padded 
when forming batch\n        gt_masks, gt_sem_cls, gt_sem_seg = self.preprocess_gt_masks(img_metas, gt_masks, gt_labels, gt_semantic_seg)\n\n        ref_gt_masks, ref_gt_sem_cls, ref_gt_sem_seg = self.preprocess_gt_masks(ref_img_metas_new,\n                                                                    ref_masks_gt, ref_gt_labels, ref_semantic_seg_gt)\n\n        x = self.extract_feat(img)\n        self.backbone.eval()\n        with torch.no_grad():\n            x_ref = self.extract_feat(ref_img)\n        self.backbone.train()\n\n        rpn_results = self.rpn_head.forward_train(x, img_metas, gt_masks,\n                                                  gt_labels, gt_sem_seg,\n                                                  gt_sem_cls)\n\n        # simple forward to get the reference results\n        ref_rpn_results = self.rpn_head.simple_test_rpn(x_ref, ref_img_metas_new)\n\n        (rpn_losses, proposal_feats, x_feats, mask_preds,\n         cls_scores) = rpn_results\n\n        (ref_proposal_feats, ref_x_feats, ref_mask_preds,\n         ref_cls_scores, ref_seg_preds) = ref_rpn_results\n\n        # forward to get the current results\n        losses, object_feats, cls_scores, mask_preds, scaled_mask_preds = self.roi_head.forward_train(\n            x_feats,\n            proposal_feats,\n            mask_preds,\n            cls_scores,\n            img_metas,\n            gt_masks,\n            gt_labels,\n            gt_bboxes_ignore=gt_bboxes_ignore,\n            gt_bboxes=gt_bboxes,\n            gt_sem_seg=gt_sem_seg,\n            gt_sem_cls=gt_sem_cls,\n            imgs_whwh=None)\n\n        # simple forward to get the reference results\n        _,  ref_cls_scores, ref_mask_preds, ref_scaled_mask_preds = self.roi_head.simple_test_mask_preds(\n            ref_x_feats,\n            ref_proposal_feats,\n            ref_mask_preds,\n            ref_cls_scores,\n            ref_img_metas_new,\n           )\n\n        # ===== Tracking Part -==== #\n        # assign 
both key frame and reference frame tracking targets\n        key_sampling_results, ref_sampling_results = [], []\n        num_imgs = len(img_metas)\n\n        x_track_fea = x_feats\n        x_track_fea_ref = ref_x_feats\n\n        for i in range(num_imgs):\n            assign_result = self.track_roi_assigner.assign(\n                scaled_mask_preds[i][:self.num_proposals].detach(), cls_scores[i][:self.num_proposals, :self.num_thing_classes].detach(),\n                gt_masks[i], gt_labels[i], img_meta=img_metas[i])\n            sampling_result = self.track_roi_sampler.sample(\n                assign_result,\n                mask_preds[i][:self.num_proposals].detach(),\n                gt_masks[i])\n            key_sampling_results.append(sampling_result)\n\n            ref_assign_result = self.track_roi_assigner.assign(\n                ref_scaled_mask_preds[i][:self.num_proposals].detach(), ref_cls_scores[i][:self.num_proposals, :self.num_thing_classes].detach(),\n                ref_gt_masks[i], ref_gt_labels[i], img_meta=ref_img_metas_new[i])\n            ref_sampling_result = self.track_roi_sampler.sample(\n                ref_assign_result,\n                ref_mask_preds[i][:self.num_proposals].detach(),\n                ref_gt_masks[i])\n            ref_sampling_results.append(ref_sampling_result)\n\n        # mask feature embeddings\n        key_masks = [res.pos_gt_masks for res in key_sampling_results]\n        key_feats = self._track_forward(x_track_fea, key_masks)\n        ref_masks = [res.pos_gt_masks for res in ref_sampling_results]\n        ref_feats = self._track_forward(x_track_fea_ref, ref_masks)\n\n        match_feats = self.track_head.match(key_feats, ref_feats,\n                                            key_sampling_results,\n                                            ref_sampling_results)\n\n        asso_targets = self.track_head.get_track_targets(\n            gt_match_indices, key_sampling_results, ref_sampling_results)\n        
loss_track = self.track_head.loss(*match_feats, *asso_targets)\n\n        losses.update(loss_track)\n        losses.update(rpn_losses)\n\n        return losses\n\n    def simple_test(self, img, img_metas, rescale=False, ref_img=None, **kwargs):\n        \"\"\"Test function without test time augmentation.\n\n        Args:\n            img (torch.Tensor): Input images of shape (N, C, H, W).\n            img_metas (list[dict]): List of image information.\n            rescale (bool): Whether to rescale the results.\n                Defaults to False.\n\n        Returns:\n            tuple: The semantic segmentation map (np.ndarray), the tracking\n                id map (np.ndarray) and three placeholder ``None`` values.\n        \"\"\"\n\n        # if ref_img is not None:\n        #     ref_img = ref_img[0]\n\n        # check whether this is the first frame of a clip\n        # (Cityscapes frame ids start at 1, other datasets at 0)\n        if \"city\" in img_metas[0]['filename']:\n            iid = img_metas[0]['iid']\n            fid = iid % 10000\n            is_first = (fid == 1)\n        else:\n            iid = kwargs['img_id'][0].item()\n            fid = iid % 10000\n            is_first = (fid == 0)\n\n        if is_first:\n            self.init_tracker()\n\n        # for current frame\n        x = self.extract_feat(img)\n        # x_track_fea = self.track_localization_fpn(x)\n        # current frame inference\n        rpn_results = self.rpn_head.simple_test_rpn(x, img_metas)\n        (proposal_feats, x_feats, mask_preds, cls_scores,\n         seg_preds) = rpn_results\n        x_track_fea = x_feats\n        cur_segm_results, query_output, cls_scores, mask_preds, scaled_mask_preds = self.roi_head.simple_test(\n            x_feats,\n            proposal_feats,\n            
mask_preds,\n            cls_scores,\n            img_metas)\n        # for tracking part\n        sorted_bbox_result, segm_result, mask_preds, panoptic_result = cur_segm_results[0]\n        panoptic_seg, segments_info = panoptic_result\n\n        # get the semantic filter\n        if self.semantic_filter:\n            seg_preds = torch.nn.functional.interpolate(seg_preds, panoptic_seg.shape, mode='bilinear', align_corners=False)\n            seg_preds = seg_preds.sigmoid()\n            seg_out = seg_preds.argmax(1)\n            semantic_thing = (seg_out < self.num_thing_classes).to(dtype=torch.float32)\n        else:\n            semantic_thing = 1.\n\n        # get sorted tracking thing ids, labels, masks, score for tracking\n        things_index_for_tracking, things_labels_for_tracking, thing_masks_for_tracking, things_score_for_tracking = \\\n            self.get_things_id_for_tracking(panoptic_seg, segments_info)\n        things_labels_for_tracking = torch.Tensor(things_labels_for_tracking).to(cls_scores.device).long()\n        if len(things_labels_for_tracking) > 0:\n            thing_masks_for_tracking_final = []\n            for mask in thing_masks_for_tracking:\n                thing_masks_for_tracking_final.append(torch.Tensor(mask).unsqueeze(0).to(\n                    x_feats.device).float())\n            thing_masks_for_tracking_final = torch.cat(thing_masks_for_tracking_final, 0)\n            thing_masks_for_tracking = thing_masks_for_tracking_final\n            thing_masks_for_tracking_scaled = F.interpolate(thing_masks_for_tracking.unsqueeze(0),\n                                                     size=x_track_fea.size()[2:], mode=\"bilinear\", align_corners=False)\n            things_bbox_for_tracking = torch.zeros((len(things_score_for_tracking), 5),\n                                                   dtype=torch.float, device=x_feats.device)\n            things_bbox_for_tracking[:, 4] = torch.tensor(things_score_for_tracking,\n                   
                                       device=things_bbox_for_tracking.device)\n            thing_masks_for_tracking_with_semantic_filter = thing_masks_for_tracking_scaled * semantic_thing\n\n        if len(things_labels_for_tracking) == 0:\n            track_feats = None\n        else:\n            # tracking embedding features\n            track_feats = self._track_forward(x_track_fea, thing_masks_for_tracking_with_semantic_filter)\n\n        if track_feats is not None:\n            # assert len(things_id_for_tracking) == len(things_labels_for_tracking)\n            things_bbox_for_tracking[:, :4] = torch.tensor(\n                tensor_mask2box(thing_masks_for_tracking_with_semantic_filter),\n                device=things_bbox_for_tracking.device)\n\n            bboxes, labels, ids = self.tracker.match(\n                bboxes=things_bbox_for_tracking,\n                labels=things_labels_for_tracking,\n                track_feats=track_feats,\n                frame_id=fid)\n            ids = ids + 1\n            # hack for unmatched into background\n            ids[ids == -1] = 0\n        else:\n            ids = []\n\n        track_maps = self.generate_track_id_maps(ids, thing_masks_for_tracking, panoptic_seg)\n\n        return self.get_semantic_seg(panoptic_seg, segments_info), track_maps, None, None, None\n\n    def _track_forward(self, x, mask_pred):\n        \"\"\"Track head forward function used in both training and testing.\n        We use mask pooling to get the fine grain features\"\"\"\n        track_feats_list = []\n\n        for i, masks in enumerate(mask_pred):\n            masks = masks.sigmoid() > 0.5\n            masks = masks.float().detach()\n            size = x.size()[2:]\n            masks = F.interpolate(masks.unsqueeze(0), size=size, mode=\"bilinear\", align_corners=True).squeeze(0)\n            track_feats = torch.einsum('nhw,chw->nc', masks, x[i])\n            track_feats = track_feats / (masks.sum(-1).sum(-1) + 1).unsqueeze(-1)\n      
      track_feats_list.append(track_feats)\n        track_feats = torch.cat(track_feats_list, 0)\n        track_feats = self.track_head(track_feats)\n        return track_feats\n\n    def forward_dummy(self, img):\n        \"\"\"Used for computing network flops.\n\n        See `mmdetection/tools/get_flops.py`\n        \"\"\"\n        # backbone\n        x = self.extract_feat(img)\n        # rpn\n        num_imgs = len(img)\n        dummy_img_metas = [\n            dict(img_shape=(800, 1333, 3)) for _ in range(num_imgs)\n        ]\n        rpn_results = self.rpn_head.simple_test_rpn(x, dummy_img_metas)\n        (proposal_feats, x_feats, mask_preds, cls_scores,\n         seg_preds) = rpn_results\n        roi_outs = self.roi_head.forward_dummy(x_feats, proposal_feats,\n                                               dummy_img_metas)\n        return roi_outs\n\n    def extract_feat(self, img):\n        \"\"\"Directly extract features from the backbone+neck.\"\"\"\n        x = self.backbone(img)\n        if self.with_neck:\n            x = self.neck(x)\n        return x\n\n    @property\n    def with_rpn(self):\n        \"\"\"bool: whether the detector has RPN\"\"\"\n        return hasattr(self, 'rpn_head') and self.rpn_head is not None\n\n    @property\n    def with_roi_head(self):\n        \"\"\"bool: whether the detector has a RoI head\"\"\"\n        return hasattr(self, 'roi_head') and self.roi_head is not None\n\n    def aug_test(self, x, proposal_list, img_metas, rescale=False, **kwargs):\n        \"\"\"Test with augmentations.\n\n        If rescale is False, then returned bboxes and masks will fit the scale\n        of imgs[0].\n        \"\"\"\n        pass\n\n    def get_things_id_for_tracking(self, panoptic_seg, seg_infos):\n        idxs = []\n        labels = []\n        masks = []\n        score = []\n        for segment in seg_infos:\n            if segment['isthing'] == True:\n                thing_mask = panoptic_seg == segment[\"id\"]\n                
masks.append(thing_mask)\n                idxs.append(segment[\"instance_id\"])\n                labels.append(segment['category_id'])\n                score.append(segment['score'])\n        return idxs, labels, masks, score\n\n    def pack_things_object(self, object_feats, ref_object_feats):\n        object_feats, ref_object_feats = object_feats.squeeze(-1).squeeze(-1), ref_object_feats.squeeze(-1).squeeze(-1)\n        thing_object_feats = torch.split(object_feats, [self.roi_head.num_proposals, self.num_stuff_classes], dim=1)[0]\n        ref_thing_object_feats = torch.split(ref_object_feats, [self.roi_head.num_proposals, self.num_stuff_classes], dim=1)[0]\n        return thing_object_feats, ref_thing_object_feats\n\n    def pack_things_masks(self, mask_pred, ref_mask_pred):\n        thing_mask_pred = torch.split(mask_pred, [self.roi_head.num_proposals, self.num_stuff_classes], dim=1)[0]\n        ref_thing_thing_mask_pred = torch.split(ref_mask_pred, [self.roi_head.num_proposals, self.num_stuff_classes], dim=1)[0]\n        return thing_mask_pred, ref_thing_thing_mask_pred\n\n    def get_semantic_seg(self, panoptic_seg, segments_info):\n        results = {}\n        masks = []\n        scores = []\n        kitti_step2cityscpaes = [11, 13]\n        semantic_seg = np.zeros(panoptic_seg.shape)\n        for segment in segments_info:\n            if segment['isthing'] == True:\n                if self.kitti_step:\n                    cat_cur = kitti_step2cityscpaes[segment[\"category_id\"]]\n                    semantic_seg[panoptic_seg == segment[\"id\"]] = cat_cur\n                else:\n                    semantic_seg[panoptic_seg == segment[\"id\"]] = segment[\"category_id\"] + 11\n            else:\n                # for stuff (0- n-1)\n                if self.kitti_step:\n                    cat_cur = segment[\"category_id\"]\n                    cat_cur -= 1\n                    offset = 0\n                    for thing_id in kitti_step2cityscpaes:\n             
           if cat_cur + offset >= thing_id:\n                            offset += 1\n                    cat_cur += offset\n                    semantic_seg[panoptic_seg == segment[\"id\"]] = cat_cur\n                else:\n                    semantic_seg[panoptic_seg == segment[\"id\"]] = segment[\"category_id\"] - 1\n        return semantic_seg\n\n    def generate_track_id_maps(self, ids, masks, panopitc_seg_maps):\n        final_id_maps = np.zeros(panopitc_seg_maps.shape)\n        if len(ids) == 0:\n            return final_id_maps\n        # assert len(things_mask_results) == len(track_results)\n        masks = masks.bool()\n        for i, id in enumerate(ids):\n            mask = masks[i].cpu().numpy()\n            final_id_maps[mask] = id\n        return final_id_maps\n\n\nimport cv2\nimport numpy as np\nimport os.path as osp\n\n\ndef log_masks_for_inference(masks_preds, names, output_dirs=\"work_dirs/vps/vps_output/thing_masks\"):\n    for i, masks in enumerate(masks_preds):\n        out_masks = np.zeros(masks_preds[0].shape).astype(np.int16)\n        masks = masks.sigmoid() > 0.5\n        masks = masks.cpu().numpy()\n        out_masks[masks==1] = 255\n        file_name = osp.join(output_dirs, names + \"_\" + str(i) + \".png\")\n        cv2.imwrite(file_name, out_masks)"
  },
  {
    "path": "knet/video/knet_quansi_dense_embed_fc.py",
    "content": "import warnings\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.nn as nn\nfrom mmcv.cnn import (ConvModule, bias_init_with_prob,\n                      build_activation_layer, build_norm_layer)\nfrom mmdet.models.builder import DETECTORS\nfrom mmdet.models.detectors import BaseDetector\nfrom mmdet.models.builder import build_head, build_neck, build_backbone, build_roi_extractor\nfrom mmdet.core import build_assigner, build_sampler\nfrom knet.video.qdtrack.builder import build_tracker\nfrom knet.det.utils import sem2ins_masks, sem2ins_masks_cityscapes, sem2ins_masks_kitti_step\nfrom unitrack.mask import tensor_mask2box\n\n\n@DETECTORS.register_module()\nclass VideoKNetQuansiEmbedFC(BaseDetector):\n    \"\"\"\n        Simple Extension of KNet to Video KNet by the implementation of VPSFuse Net.\n    \"\"\"\n\n    def __init__(self,\n                 backbone,\n                 neck=None,\n                 rpn_head=None,\n                 roi_head=None,\n                 track_head=None,\n                 extra_neck=None,\n                 track_mhsa=False,\n                 tracker=None,\n                 train_cfg=None,\n                 test_cfg=None,\n                 track_train_cfg=None,\n                 pretrained=None,\n                 init_cfg=None,\n                 num_thing_classes=80,\n                 num_stuff_classes=53,\n                 mask_assign_stride=4,\n                 ignore_label=255,\n                 thing_label_in_seg=0,\n                 detach_mask_emd=False,\n                 cityscapes=False,\n                 kitti_step=False,\n                 cityscapes_short=False,\n                 freeze_detector=False,\n                 semantic_filter=True,\n                 # linking parameters\n                 link_previous=False,\n                 bbox_roi_extractor=None,\n                 **kwargs):\n        super(VideoKNetQuansiEmbedFC, self).__init__(init_cfg)\n\n        if pretrained:\n 
           warnings.warn('DeprecationWarning: pretrained is deprecated, '\n                          'please use \"init_cfg\" instead')\n            backbone.pretrained = pretrained\n        self.backbone = build_backbone(backbone)\n\n        if neck is not None:\n            self.neck = build_neck(neck)\n\n        if extra_neck is not None:\n            self.extra_neck = build_neck(extra_neck)\n\n        if rpn_head is not None:\n            rpn_train_cfg = train_cfg.rpn if train_cfg is not None else None\n            rpn_head_ = rpn_head.copy()\n            rpn_head_.update(train_cfg=rpn_train_cfg, test_cfg=test_cfg.rpn)\n            self.rpn_head = build_head(rpn_head_)\n\n        if roi_head is not None:\n            # update train and test cfg here for now\n            # TODO: refactor assigner & sampler\n            rcnn_train_cfg = train_cfg.rcnn if train_cfg is not None else None\n            roi_head.update(train_cfg=rcnn_train_cfg)\n            roi_head.update(test_cfg=test_cfg.rcnn)\n            roi_head.pretrained = pretrained\n            self.roi_head = build_head(roi_head)\n\n        if track_head is not None:\n            self.track_train_cfg = track_train_cfg\n            self.track_head = build_head(track_head)\n            self.init_track_assigner_sampler()\n            if bbox_roi_extractor is not None:\n                self.track_roi_extractor = build_roi_extractor(\n                    bbox_roi_extractor)\n\n        if tracker is not None:\n            self.tracker_cfg = tracker\n\n        if freeze_detector:\n            self._freeze_detector()\n\n        self.train_cfg = train_cfg\n        self.test_cfg = test_cfg\n        self.num_proposals = self.rpn_head.num_proposals\n        self.num_thing_classes = num_thing_classes\n        self.num_stuff_classes = num_stuff_classes\n        self.mask_assign_stride = mask_assign_stride\n        self.thing_label_in_seg = thing_label_in_seg\n        self.ignore_label = ignore_label\n        
self.cityscapes = cityscapes  # whether to train the cityscape panoptic segmentation\n        self.kitti_step = kitti_step  # whether to train the kitti step panoptic segmentation\n        self.cityscapes_short = cityscapes_short  # whether to test with short clips (300)\n\n        self.semantic_filter = semantic_filter\n        self.link_previous = link_previous\n        self.detach_mask_emd = detach_mask_emd\n        self.track_mhsa = track_mhsa\n        # add embedding fcs for the final stage queries\n        num_emb_fcs = 1\n        act_cfg = dict(type='ReLU', inplace=True)\n        in_channels = 256\n        out_channels = 256\n        self.embed_fcs = nn.ModuleList()\n        for _ in range(num_emb_fcs):\n            self.embed_fcs.append(\n                nn.Linear(in_channels, in_channels, bias=False))\n            self.embed_fcs.append(\n                build_norm_layer(dict(type='LN'), in_channels)[1])\n            self.embed_fcs.append(build_activation_layer(act_cfg))\n\n        self.fc_embed = nn.Linear(in_channels, out_channels)\n\n    def init_tracker(self):\n        self.tracker = build_tracker(self.tracker_cfg)\n\n    def _freeze_detector(self):\n\n        self.detector = [\n            self.rpn_head, self.roi_head\n        ]\n        for model in self.detector:\n            model.eval()\n            for param in model.parameters():\n                param.requires_grad = False\n\n    def init_track_assigner_sampler(self):\n        \"\"\"Initialize assigner and sampler.\"\"\"\n\n        self.track_roi_assigner = build_assigner(\n            self.track_train_cfg.assigner)\n        self.track_share_assigner = False\n\n        self.track_roi_sampler = build_sampler(\n            self.track_train_cfg.sampler, context=self)\n        self.track_share_sampler = False\n\n    def preprocess_gt_masks(self, img_metas, gt_masks, gt_labels, gt_semantic_seg):\n        # gt_masks and gt_semantic_seg are not padded when forming batch\n        gt_masks_tensor = []\n  
      gt_sem_seg = []\n        gt_sem_cls = []\n        # batch_input_shape should be the same across images\n        pad_H, pad_W = img_metas[0]['batch_input_shape']\n        assign_H = pad_H // self.mask_assign_stride\n        assign_W = pad_W // self.mask_assign_stride\n\n        for i, gt_mask in enumerate(gt_masks):\n            mask_tensor = gt_mask.to_tensor(torch.float, gt_labels[0].device)\n            if gt_mask.width != pad_W or gt_mask.height != pad_H:\n                pad_wh = (0, pad_W - gt_mask.width, 0, pad_H - gt_mask.height)\n                mask_tensor = F.pad(mask_tensor, pad_wh, value=0)\n\n            if gt_semantic_seg is not None:\n                # gt_semantic_seg is zero-padded when forming a batch;\n                # set the padded region to ignore_label\n                gt_semantic_seg[\n                i, :, img_metas[i]['img_shape'][0]:, :] = self.ignore_label\n                gt_semantic_seg[\n                i, :, :, img_metas[i]['img_shape'][1]:] = self.ignore_label\n                if self.cityscapes:\n                    sem_labels, sem_seg = sem2ins_masks_cityscapes(\n                        gt_semantic_seg[i],\n                        ignore_label=self.ignore_label,\n                        label_shift=self.num_thing_classes,\n                        thing_label_in_seg=list(range(self.num_stuff_classes,\n                                                      self.num_thing_classes + self.num_stuff_classes))\n                    )\n                elif self.kitti_step:\n                    sem_labels, sem_seg = sem2ins_masks_kitti_step(\n                        gt_semantic_seg[i],\n                        ignore_label=self.ignore_label,\n                        label_shift=2,\n                        thing_label_in_seg=(11, 13))\n                else:\n                    sem_labels, sem_seg = sem2ins_masks(\n                        gt_semantic_seg[i],\n                        ignore_label=self.ignore_label,\n                     
   label_shift=self.num_thing_classes,\n                        thing_label_in_seg=self.thing_label_in_seg)\n\n                if sem_seg.shape[0] == 0:\n                    gt_sem_seg.append(\n                        mask_tensor.new_zeros(\n                            (mask_tensor.size(0), assign_H, assign_W)))\n                else:\n                    gt_sem_seg.append(\n                        F.interpolate(\n                            sem_seg[None], (assign_H, assign_W),\n                            mode='bilinear',\n                            align_corners=False)[0])\n                gt_sem_cls.append(sem_labels)\n            else:\n                gt_sem_seg = None\n                gt_sem_cls = None\n\n            if mask_tensor.shape[0] == 0:\n                gt_masks_tensor.append(\n                    mask_tensor.new_zeros(\n                        (mask_tensor.size(0), assign_H, assign_W)))\n            else:\n                gt_masks_tensor.append(\n                    F.interpolate(\n                        mask_tensor[None], (assign_H, assign_W),  # downsample to 1/4 resolution\n                        mode='bilinear',\n                        align_corners=False)[0])\n\n        return gt_masks_tensor, gt_sem_cls, gt_sem_seg\n\n    def forward_train(self,\n                      img,\n                      img_metas,\n                      gt_bboxes=None,\n                      gt_labels=None,\n                      gt_bboxes_ignore=None,\n                      gt_masks=None,\n                      gt_semantic_seg=None,\n                      gt_instance_ids=None,\n                      ref_img=None,\n                      ref_img_metas=None,\n                      ref_gt_bboxes_ignore=None,\n                      ref_gt_labels=None,\n                      ref_gt_bboxes=None,\n                      ref_gt_masks=None,\n                      ref_gt_semantic_seg=None,\n                      ref_gt_instance_ids=None,\n                      
proposals=None,\n                      **kwargs):\n        \"\"\"Forward function of SparseR-CNN-like network in train stage.\n\n        Args:\n            img (Tensor): of shape (N, C, H, W) encoding input images.\n                Typically these should be mean centered and std scaled.\n            img_metas (list[dict]): list of image info dict where each dict\n                has: 'img_shape', 'scale_factor', 'flip', and may also contain\n                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.\n                For details on the values of these keys see\n                :class:`mmdet.datasets.pipelines.Collect`.\n            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with\n                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.\n            gt_labels (list[Tensor]): class indices corresponding to each box\n            gt_bboxes_ignore (None | list[Tensor]): specify which bounding\n                boxes can be ignored when computing the loss.\n            gt_masks (list): Ground truth instance masks for each image.\n                Required by this architecture.\n            proposals (List[Tensor], optional): Not supported by KNet;\n                must be None.\n\n            # This is for video only:\n            ref_img (Tensor): of shape (N, 1, C, H, W) encoding the reference\n                image of each key image. Typically these should be mean\n                centered and std scaled. The singleton dimension is squeezed\n                in this method.\n\n            ref_img_metas (list[list[dict]]): For each key image, a list of\n                reference image information dicts (only the first one is used),\n                where each dict has: 'img_shape', 'scale_factor', 'flip',\n                and may also contain 'filename', 'ori_shape', 'pad_shape', and\n                'img_norm_cfg'.
For details on the values of these keys see\n                `mmtrack/datasets/pipelines/formatting.py:VideoCollect`.\n\n            ref_gt_bboxes (list[Tensor]): The list only has one Tensor. The\n                Tensor contains ground truth bboxes for each reference image\n                with shape (num_all_ref_gts, 5) in\n                [ref_img_id, tl_x, tl_y, br_x, br_y] format. The ref_img_id\n                starts from 0 and denotes the id of the reference image for\n                each key image.\n\n            ref_gt_labels (list[Tensor]): The list only has one Tensor. The\n                Tensor contains class indices corresponding to each reference\n                box with shape (num_all_ref_gts, 2) in\n                [ref_img_id, class_index].\n\n        Returns:\n            dict[str, Tensor]: a dictionary of loss components\n        \"\"\"\n        batch_input_shape = tuple(img[0].size()[-2:])\n        for img_meta in img_metas:\n            img_meta['batch_input_shape'] = batch_input_shape\n\n        assert proposals is None, 'KNet does not support' \\\n                                  ' external proposals'\n        assert gt_masks is not None\n        assert gt_instance_ids is not None\n\n        # preprocess the reference images\n        ref_img = ref_img.squeeze(1)  # (b,3,h,w)\n        img_h, img_w = batch_input_shape\n        ref_masks_gt = []\n        for ref_gt_mask in ref_gt_masks:\n            ref_masks_gt.append(ref_gt_mask[0])\n\n        ref_labels_gt = []\n        for ref_gt_label in ref_gt_labels:\n            ref_labels_gt.append(ref_gt_label[:, 1].long())\n        ref_gt_labels = ref_labels_gt\n\n        ref_semantic_seg_gt = ref_gt_semantic_seg.squeeze(1)\n\n        ref_gt_instance_id_list = []\n        for ref_gt_instance_id in ref_gt_instance_ids:\n            ref_gt_instance_id_list.append(ref_gt_instance_id[:, 1].long())\n\n        ref_img_metas_new = []\n        for ref_img_meta in ref_img_metas:\n            
ref_img_meta[0]['batch_input_shape'] = batch_input_shape\n            ref_img_metas_new.append(ref_img_meta[0])\n\n        # prepare the gt_match_indices\n        gt_pids_list = []\n        for i in range(len(ref_gt_instance_id_list)):\n            ref_ids = ref_gt_instance_id_list[i].cpu().data.numpy().tolist()\n            gt_ids = gt_instance_ids[i].cpu().data.numpy().tolist()\n            gt_pids = [ref_ids.index(i) if i in ref_ids else -1 for i in gt_ids]\n            gt_pids_list.append(torch.LongTensor([gt_pids]).to(img.device)[0])\n\n        gt_match_indices = gt_pids_list\n\n        # gt_masks and gt_semantic_seg are not padded when forming batch\n        gt_masks, gt_sem_cls, gt_sem_seg = self.preprocess_gt_masks(img_metas, gt_masks, gt_labels, gt_semantic_seg)\n        ref_gt_masks, ref_gt_sem_cls, ref_gt_sem_seg = self.preprocess_gt_masks(ref_img_metas_new,\n                                                                                ref_masks_gt, ref_gt_labels,\n                                                                                ref_semantic_seg_gt)\n\n        x = self.extract_feat(img)\n        x_ref = self.extract_feat(ref_img)\n\n        rpn_results = self.rpn_head.forward_train(x, img_metas, gt_masks,\n                                                  gt_labels, gt_sem_seg,\n                                                  gt_sem_cls)\n\n        # simple forward to get the reference results\n        self.rpn_head.eval()\n        ref_rpn_results = self.rpn_head.simple_test_rpn(x_ref, ref_img_metas_new)\n        self.rpn_head.train()\n\n        (rpn_losses, proposal_feats, x_feats, mask_preds,\n         cls_scores) = rpn_results\n\n        (ref_proposal_feats, ref_x_feats, ref_mask_preds,\n         ref_cls_scores, ref_seg_preds) = ref_rpn_results\n\n        ref_obj_feats, ref_cls_scores, ref_mask_preds, ref_scaled_mask_preds = self.roi_head.simple_test_mask_preds(\n            ref_x_feats,\n            ref_proposal_feats,\n            
ref_mask_preds,\n            ref_cls_scores,\n            ref_img_metas_new,\n        )\n\n        if self.link_previous:\n            losses, object_feats, cls_scores, mask_preds, scaled_mask_preds, object_feats_track = self.roi_head.forward_train_with_previous(\n                x_feats,\n                proposal_feats,\n                mask_preds,\n                cls_scores,\n                img_metas,\n                gt_masks,\n                gt_labels,\n                gt_bboxes_ignore=gt_bboxes_ignore,\n                gt_bboxes=gt_bboxes,\n                gt_sem_seg=gt_sem_seg,\n                gt_sem_cls=gt_sem_cls,\n                imgs_whwh=None,\n                previous_obj_feats=ref_obj_feats,\n                previous_mask_preds=ref_scaled_mask_preds,\n                previous_x_feats=ref_x_feats,\n            )\n        else:\n            # forward to get the current results\n            losses, object_feats, cls_scores, mask_preds, scaled_mask_preds = self.roi_head.forward_train(\n                x_feats,\n                proposal_feats,\n                mask_preds,\n                cls_scores,\n                img_metas,\n                gt_masks,\n                gt_labels,\n                gt_bboxes_ignore=gt_bboxes_ignore,\n                gt_bboxes=gt_bboxes,\n                gt_sem_seg=gt_sem_seg,\n                gt_sem_cls=gt_sem_cls,\n                imgs_whwh=None)\n\n        # ===== Tracking Part ===== #\n        # assign both key frame and reference frame tracking targets\n        key_sampling_results, ref_sampling_results = [], []\n        num_imgs = len(img_metas)\n\n        for i in range(num_imgs):\n            assign_result = self.track_roi_assigner.assign(\n                scaled_mask_preds[i][:self.num_proposals].detach(),\n                cls_scores[i][:self.num_proposals, :self.num_thing_classes].detach(),\n                gt_masks[i], gt_labels[i], img_meta=img_metas[i])\n            
sampling_result = 
self.track_roi_sampler.sample(\n                assign_result,\n                mask_preds[i][:self.num_proposals].detach(),\n                gt_masks[i])\n            key_sampling_results.append(sampling_result)\n\n            ref_assign_result = self.track_roi_assigner.assign(\n                ref_scaled_mask_preds[i][:self.num_proposals].detach(),\n                ref_cls_scores[i][:self.num_proposals, :self.num_thing_classes].detach(),\n                ref_gt_masks[i], ref_gt_labels[i], img_meta=ref_img_metas_new[i])\n            ref_sampling_result = self.track_roi_sampler.sample(\n                ref_assign_result,\n                ref_mask_preds[i][:self.num_proposals].detach(),\n                ref_gt_masks[i])\n            ref_sampling_results.append(ref_sampling_result)\n        if self.detach_mask_emd:\n            object_feats = object_feats.detach()\n            ref_obj_feats = ref_obj_feats.detach()\n\n        if self.link_previous:\n            object_feats = object_feats_track\n\n        N = object_feats.shape[0]\n        emb_feat = object_feats.squeeze(-2).squeeze(-1)[:, :self.num_proposals]\n\n        for emb_layer in self.embed_fcs:\n            emb_feat = emb_layer(emb_feat)\n        object_feats_embed = self.fc_embed(emb_feat).view(N, self.num_proposals, -1)\n\n        ref_emb_feat = ref_obj_feats.squeeze(-2).squeeze(-1)[:, :self.num_proposals]\n        for emb_layer in self.embed_fcs:\n            ref_emb_feat = emb_layer(ref_emb_feat)\n        ref_object_feats_embed = self.fc_embed(ref_emb_feat).view(N, self.num_proposals, -1)\n\n        # gather the embeddings of the positively assigned proposals of each image\n        key_emb_indexs = [res.pos_inds for res in key_sampling_results]\n        object_feats_embed_list = []\n        for i in range(len(key_emb_indexs)):\n            object_feats_embed_list.append(object_feats_embed[i, key_emb_indexs[i], :])\n\n        key_feats = self._track_forward(object_feats_embed_list)\n\n        ref_emb_indexs = 
[res.pos_inds for res in ref_sampling_results]\n        ref_object_feats_embed_list = []\n        for i in range(len(ref_emb_indexs)):\n            ref_object_feats_embed_list.append(ref_object_feats_embed[i, ref_emb_indexs[i], :])\n\n        ref_feats = self._track_forward(ref_object_feats_embed_list)\n\n        match_feats = self.track_head.match(key_feats, ref_feats,\n                                            key_sampling_results,\n                                            ref_sampling_results)\n\n        asso_targets = self.track_head.get_track_targets(\n            gt_match_indices, key_sampling_results, ref_sampling_results)\n        loss_track = self.track_head.loss(*match_feats, *asso_targets)\n\n        losses.update(loss_track)\n        losses.update(rpn_losses)\n\n        return losses\n\n    def simple_test(self, img, img_metas, rescale=False, ref_img=None, **kwargs):\n        \"\"\"Test function without test time augmentation.\n\n        Args:\n            img (torch.Tensor): Input images of shape (N, C, H, W).\n            img_metas (list[dict]): List of image information.\n            rescale (bool): Whether to rescale the results.\n                Defaults to False.\n\n        Returns:\n            list[list[np.ndarray]]: BBox results of each image and classes.\n                The outer list corresponds to each image. 
The inner list\n                corresponds to each class.\n        \"\"\"\n        # set the dataset type\n        # whether is the first frame for such clips\n        if self.cityscapes and not self.kitti_step and not self.cityscapes_short:\n            iid = img_metas[0]['iid']\n            fid = iid % 10000\n            is_first = (fid == 1)\n        else:\n            iid = kwargs['img_id'][0].item()\n            fid = iid % 10000\n            is_first = (fid == 0)\n\n        if is_first:\n            self.init_tracker()\n            self.obj_feats_memory = None\n            self.x_feats_memory = None\n            self.mask_preds_memory = None\n\n        # for current frame\n        x = self.extract_feat(img)\n        # current frame inference\n        rpn_results = self.rpn_head.simple_test_rpn(x, img_metas)\n        (proposal_feats, x_feats, mask_preds, cls_scores,\n         seg_preds) = rpn_results\n\n        if self.link_previous:\n            cur_segm_results, obj_feats, cls_scores, mask_preds, scaled_mask_preds = self.roi_head.simple_test_with_previous(\n                x_feats,\n                proposal_feats,\n                mask_preds,\n                cls_scores,\n                img_metas,\n                previous_obj_feats=self.obj_feats_memory,\n                previous_mask_preds=self.mask_preds_memory,\n                previous_x_feats=self.x_feats_memory,\n                is_first=is_first,\n            )\n\n            self.obj_feats_memory = obj_feats\n            self.x_feats_memory = x_feats\n            self.mask_preds_memory = scaled_mask_preds\n        else:\n            cur_segm_results, query_output, cls_scores, mask_preds, scaled_mask_preds = self.roi_head.simple_test(\n                x_feats,\n                proposal_feats,\n                mask_preds,\n                cls_scores,\n                img_metas)\n\n        # for tracking part\n        _, segm_result, mask_preds, panoptic_result, query_output = cur_segm_results[0]\n   
     panoptic_seg, segments_info = panoptic_result\n\n        # get sorted tracking thing ids, labels, masks, score for tracking\n        things_index_for_tracking, things_labels_for_tracking, thing_masks_for_tracking, things_score_for_tracking = \\\n            self.get_things_id_for_tracking(panoptic_seg, segments_info)\n        things_labels_for_tracking = torch.Tensor(things_labels_for_tracking).to(cls_scores.device).long()\n\n        # get the semantic filter\n        if self.semantic_filter:\n            seg_preds = torch.nn.functional.interpolate(seg_preds, panoptic_seg.shape, mode='bilinear',\n                                                        align_corners=False)\n            seg_preds = seg_preds.sigmoid()\n            seg_out = seg_preds.argmax(1)\n            semantic_thing = (seg_out < self.num_thing_classes).to(dtype=torch.float32)\n        else:\n            semantic_thing = 1.\n\n        if len(things_labels_for_tracking) > 0:\n            things_bbox_for_tracking = torch.zeros((len(things_score_for_tracking), 5),\n                                                   dtype=torch.float, device=x_feats.device)\n            things_bbox_for_tracking[:, 4] = torch.tensor(things_score_for_tracking,\n                                                          device=things_bbox_for_tracking.device)\n\n            thing_masks_for_tracking_final = []\n            for mask in thing_masks_for_tracking:\n                thing_masks_for_tracking_final.append(torch.Tensor(mask).unsqueeze(0).to(\n                    x_feats.device).float())\n            thing_masks_for_tracking_final = torch.cat(thing_masks_for_tracking_final, 0)\n            thing_masks_for_tracking = thing_masks_for_tracking_final\n            thing_masks_for_tracking_with_semantic_filter = thing_masks_for_tracking_final * semantic_thing\n\n        if len(things_labels_for_tracking) == 0:\n            track_feats = None\n        else:\n            # tracking embeddings\n            N, _, _, _ = 
query_output.shape\n            emb_feat = query_output.squeeze(-2).squeeze(-1).unsqueeze(0)  # (n,d,1,1) -> (1,n,d)\n\n            for emb_layer in self.embed_fcs:\n                emb_feat = emb_layer(emb_feat)\n            object_feats_embed = self.fc_embed(emb_feat).view(1, N, -1)\n\n            object_feats_embed_for_tracking = object_feats_embed.squeeze(0)\n            # tracking embedding features\n            track_feats = self._track_forward([object_feats_embed_for_tracking])\n\n        if track_feats is not None:\n            things_bbox_for_tracking[:, :4] = torch.tensor(\n                tensor_mask2box(thing_masks_for_tracking_with_semantic_filter),\n                device=things_bbox_for_tracking.device)\n            bboxes, labels, ids = self.tracker.match(\n                bboxes=things_bbox_for_tracking,\n                labels=things_labels_for_tracking,\n                track_feats=track_feats,\n                frame_id=fid)\n            ids = ids + 1\n            ids[ids == -1] = 0\n        else:\n            ids = []\n\n        track_maps = self.generate_track_id_maps(ids, thing_masks_for_tracking, panoptic_seg)\n\n        semantic_map = self.get_semantic_seg(panoptic_seg, segments_info)\n\n        from scripts.visualizer import trackmap2rgb, cityscapes_cat2rgb, draw_bbox_on_img\n        vis_tracker = trackmap2rgb(track_maps)\n        vis_sem = cityscapes_cat2rgb(semantic_map)\n        if len(things_labels_for_tracking):\n            vis_tracker = draw_bbox_on_img(vis_tracker, things_bbox_for_tracking.cpu().numpy())\n\n        # Visualization usage\n        return semantic_map, track_maps, None, vis_sem, vis_tracker\n\n    def _track_forward(self, track_feats, x=None, mask_pred=None):\n        \"\"\"Track head forward function used in both training and testing.\n        Concatenates the per-image query embeddings and feeds them to the\n        track head; the RoI-pooling path below is disabled.\"\"\"\n        # if not self.training:\n        #     mask_pred = [mask_pred]\n        # 
bbox_list = batch_mask2boxlist(mask_pred)\n        # track_rois = bboxlist2roi(bbox_list)\n        # track_rois = track_rois.clamp(min=0.0)\n        # track_feats = self.track_roi_extractor(x[:self.track_roi_extractor.num_inputs], track_rois)\n        track_feats = torch.cat(track_feats, 0)\n        # print(track_feats.shape)\n        # print(track_feats.shape)\n        # track_feats = track_feats\n\n        track_feats = self.track_head(track_feats)\n\n        return track_feats\n\n    def forward_dummy(self, img):\n        \"\"\"Used for computing network flops.\n\n        See `mmdetection/tools/get_flops.py`\n        \"\"\"\n        # backbone\n        x = self.extract_feat(img)\n        # rpn\n        num_imgs = len(img)\n        dummy_img_metas = [\n            dict(img_shape=(800, 1333, 3)) for _ in range(num_imgs)\n        ]\n        rpn_results = self.rpn_head.simple_test_rpn(x, dummy_img_metas)\n        (proposal_feats, x_feats, mask_preds, cls_scores,\n         seg_preds) = rpn_results\n        roi_outs = self.roi_head.forward_dummy(x_feats, proposal_feats,\n                                               dummy_img_metas)\n        return roi_outs\n\n    def extract_feat(self, img):\n        \"\"\"Directly extract features from the backbone+neck.\"\"\"\n        x = self.backbone(img)\n        if self.with_neck:\n            x = self.neck(x)\n        return x\n\n    @property\n    def with_rpn(self):\n        \"\"\"bool: whether the detector has RPN\"\"\"\n        return hasattr(self, 'rpn_head') and self.rpn_head is not None\n\n    @property\n    def with_roi_head(self):\n        \"\"\"bool: whether the detector has a RoI head\"\"\"\n        return hasattr(self, 'roi_head') and self.roi_head is not None\n\n    def aug_test(self, x, proposal_list, img_metas, rescale=False, **kwargs):\n        \"\"\"Test with augmentations.\n\n        If rescale is False, then returned bboxes and masks will fit the scale\n        of imgs[0].\n        \"\"\"\n        pass\n\n  
  def get_things_id_for_tracking(self, panoptic_seg, seg_infos):\n        idxs = []\n        labels = []\n        masks = []\n        score = []\n        for segment in seg_infos:\n            if segment['isthing'] == True:\n                thing_mask = panoptic_seg == segment[\"id\"]\n                masks.append(thing_mask)\n                idxs.append(segment[\"instance_id\"])\n                labels.append(segment['category_id'])\n                score.append(segment['score'])\n        return idxs, labels, masks, score\n\n    def pack_things_object(self, object_feats, ref_object_feats):\n        object_feats, ref_object_feats = object_feats.squeeze(-1).squeeze(-1), ref_object_feats.squeeze(-1).squeeze(-1)\n        thing_object_feats = torch.split(object_feats, [self.roi_head.num_proposals, self.num_stuff_classes], dim=1)[0]\n        ref_thing_object_feats = \\\n        torch.split(ref_object_feats, [self.roi_head.num_proposals, self.num_stuff_classes], dim=1)[0]\n        return thing_object_feats, ref_thing_object_feats\n\n    def pack_things_masks(self, mask_pred, ref_mask_pred):\n        thing_mask_pred = torch.split(mask_pred, [self.roi_head.num_proposals, self.num_stuff_classes], dim=1)[0]\n        ref_thing_thing_mask_pred = \\\n        torch.split(ref_mask_pred, [self.roi_head.num_proposals, self.num_stuff_classes], dim=1)[0]\n        return thing_mask_pred, ref_thing_thing_mask_pred\n\n    def get_semantic_seg(self, panoptic_seg, segments_info):\n        results = {}\n        masks = []\n        scores = []\n        kitti_step2cityscpaes = [11, 13]\n        semantic_seg = np.zeros(panoptic_seg.shape)\n        for segment in segments_info:\n            if segment['isthing'] == True:\n                if self.kitti_step:\n                    cat_cur = kitti_step2cityscpaes[segment[\"category_id\"]]\n                    semantic_seg[panoptic_seg == segment[\"id\"]] = cat_cur\n                else:\n                    semantic_seg[panoptic_seg == 
segment[\"id\"]] = segment[\"category_id\"] + self.num_stuff_classes\n            else:\n                # for stuff (0- n-1)\n                if self.kitti_step:\n                    cat_cur = segment[\"category_id\"]\n                    cat_cur -= 1\n                    offset = 0\n                    for thing_id in kitti_step2cityscpaes:\n                        if cat_cur + offset >= thing_id:\n                            offset += 1\n                    cat_cur += offset\n                    semantic_seg[panoptic_seg == segment[\"id\"]] = cat_cur\n                else:\n                    semantic_seg[panoptic_seg == segment[\"id\"]] = segment[\"category_id\"] - 1\n        return semantic_seg\n\n    def generate_track_id_maps(self, ids, masks, panopitc_seg_maps):\n        final_id_maps = np.zeros(panopitc_seg_maps.shape)\n        if len(ids) == 0:\n            return final_id_maps\n        # assert len(things_mask_results) == len(track_results)\n        masks = masks.bool()\n        for i, id in enumerate(ids):\n            mask = masks[i].cpu().numpy()\n            final_id_maps[mask] = id\n        return final_id_maps"
  },
  {
    "path": "knet/video/knet_quansi_dense_embed_fc_joint_train.py",
    "content": "import warnings\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.nn as nn\nfrom mmcv.cnn import (ConvModule, bias_init_with_prob,\n                      build_activation_layer, build_norm_layer)\nfrom mmdet.models.builder import DETECTORS\nfrom mmdet.models.detectors import BaseDetector\nfrom mmdet.models.builder import build_head, build_neck, build_backbone, build_roi_extractor\nfrom mmdet.core import build_assigner, build_sampler\nfrom knet.video.qdtrack.builder import build_tracker\nfrom knet.det.utils import sem2ins_masks, sem2ins_masks_cityscapes, sem2ins_masks_kitti_step\nfrom unitrack.mask import tensor_mask2box\n\n\n@DETECTORS.register_module()\nclass VideoKNetQuansiEmbedFCJointTrain(BaseDetector):\n    \"\"\"\n        Simple Extension of KNet to Video KNet by the implementation of VPSFuse Net.\n    \"\"\"\n    def __init__(self,\n                 backbone,\n                 neck=None,\n                 rpn_head=None,\n                 roi_head=None,\n                 track_head=None,\n                 extra_neck=None,\n                 track_localization_fpn=None,\n                 tracker=None,\n                 train_cfg=None,\n                 test_cfg=None,\n                 track_train_cfg=None,\n                 pretrained=None,\n                 init_cfg=None,\n                 num_thing_classes=80,\n                 num_stuff_classes=53,\n                 mask_assign_stride=4,\n                 ignore_label=255,\n                 thing_label_in_seg=0,\n                 detach_mask_emd=False,\n                 cityscapes=False,\n                 kitti_step=False,\n                 cityscapes_short=False,\n                 vipseg=False,\n                 freeze_detector=False,\n                 semantic_filter=True,\n                 # linking parameters\n                 link_previous=False,\n                 bbox_roi_extractor=None,\n                 **kwargs):\n        
super(VideoKNetQuansiEmbedFCJointTrain, self).__init__(init_cfg)\n\n        if pretrained:\n            warnings.warn('DeprecationWarning: pretrained is deprecated, '\n                          'please use \"init_cfg\" instead')\n            backbone.pretrained = pretrained\n        self.backbone = build_backbone(backbone)\n\n        if neck is not None:\n            self.neck = build_neck(neck)\n\n        if extra_neck is not None:\n            self.extra_neck = build_neck(extra_neck)\n\n        if rpn_head is not None:\n            rpn_train_cfg = train_cfg.rpn if train_cfg is not None else None\n            rpn_head_ = rpn_head.copy()\n            rpn_head_.update(train_cfg=rpn_train_cfg, test_cfg=test_cfg.rpn)\n            self.rpn_head = build_head(rpn_head_)\n\n        if roi_head is not None:\n            # update train and test cfg here for now\n            # TODO: refactor assigner & sampler\n            rcnn_train_cfg = train_cfg.rcnn if train_cfg is not None else None\n            roi_head.update(train_cfg=rcnn_train_cfg)\n            roi_head.update(test_cfg=test_cfg.rcnn)\n            roi_head.pretrained = pretrained\n            self.roi_head = build_head(roi_head)\n\n        if track_head is not None:\n            self.track_train_cfg = track_train_cfg\n            self.track_head = build_head(track_head)\n            self.init_track_assigner_sampler()\n            if track_localization_fpn is not None:\n                self.track_localization_fpn = build_neck(track_localization_fpn)\n\n            if bbox_roi_extractor is not None:\n                self.track_roi_extractor = build_roi_extractor(\n                    bbox_roi_extractor)\n\n        if tracker is not None:\n            self.tracker_cfg = tracker\n\n        if freeze_detector:\n            self._freeze_detector()\n\n        self.train_cfg = train_cfg\n        self.test_cfg = test_cfg\n        self.num_proposals = self.rpn_head.num_proposals\n        self.num_thing_classes = 
num_thing_classes\n        self.num_stuff_classes = num_stuff_classes\n        self.mask_assign_stride = mask_assign_stride\n        self.thing_label_in_seg = thing_label_in_seg\n        self.ignore_label = ignore_label\n        self.cityscapes = cityscapes  # whether to train the cityscape panoptic segmentation\n        self.kitti_step = kitti_step  # whether to train the kitti step panoptic segmentation\n        self.cityscapes_short = cityscapes_short  # whether to test the cityscape short panoptic segmentation\n        self.vipseg = vipseg  # whether to test the vip panoptic segmentation\n        self.semantic_filter = semantic_filter\n        self.link_previous = link_previous\n        self.detach_mask_emd = detach_mask_emd\n        # add embedding fcs for the final stage queries\n        num_emb_fcs = 1\n        act_cfg = dict(type='ReLU', inplace=True)\n        in_channels = 256\n        out_channels = 256\n        self.embed_fcs = nn.ModuleList()\n        for _ in range(num_emb_fcs):\n            self.embed_fcs.append(\n                nn.Linear(in_channels, in_channels, bias=False))\n            self.embed_fcs.append(\n                build_norm_layer(dict(type='LN'), in_channels)[1])\n            self.embed_fcs.append(build_activation_layer(act_cfg))\n\n        self.fc_embed = nn.Linear(in_channels, out_channels)\n\n    def init_tracker(self):\n        self.tracker = build_tracker(self.tracker_cfg)\n\n    def _freeze_detector(self):\n\n        self.detector = [\n            self.rpn_head, self.roi_head\n        ]\n        for model in self.detector:\n            model.eval()\n            for param in model.parameters():\n                param.requires_grad = False\n\n    def init_track_assigner_sampler(self):\n        \"\"\"Initialize assigner and sampler.\"\"\"\n\n        self.track_roi_assigner = build_assigner(\n            self.track_train_cfg.assigner)\n        self.track_share_assigner = False\n\n        self.track_roi_sampler = build_sampler(\n     
       self.track_train_cfg.sampler, context=self)\n        self.track_share_sampler = False\n\n    def preprocess_gt_masks(self, img_metas, gt_masks, gt_labels, gt_semantic_seg):\n        # gt_masks and gt_semantic_seg are not padded when forming batch\n        gt_masks_tensor = []\n        gt_sem_seg = []\n        gt_sem_cls = []\n        # batch_input_shape should be the same across images\n        pad_H, pad_W = img_metas[0]['batch_input_shape']\n        assign_H = pad_H // self.mask_assign_stride\n        assign_W = pad_W // self.mask_assign_stride\n\n        for i, gt_mask in enumerate(gt_masks):\n            mask_tensor = gt_mask.to_tensor(torch.float, gt_labels[0].device)\n            if gt_mask.width != pad_W or gt_mask.height != pad_H:\n                pad_wh = (0, pad_W - gt_mask.width, 0, pad_H - gt_mask.height)\n                mask_tensor = F.pad(mask_tensor, pad_wh, value=0)\n\n            if gt_semantic_seg is not None:\n                # gt_semantic_seg is padded with zeros when forming a batch,\n                # so set the padded region to ignore_label\n                gt_semantic_seg[\n                i, :, img_metas[i]['img_shape'][0]:, :] = self.ignore_label\n                gt_semantic_seg[\n                i, :, :, img_metas[i]['img_shape'][1]:] = self.ignore_label\n\n                if self.cityscapes or self.vipseg:\n                    sem_labels, sem_seg = sem2ins_masks_cityscapes(\n                        gt_semantic_seg[i],\n                        ignore_label=self.ignore_label,\n                        label_shift=self.num_thing_classes,\n                        thing_label_in_seg=list(range(self.num_stuff_classes,\n                                                      self.num_thing_classes + self.num_stuff_classes))\n                    )\n                elif self.kitti_step:\n                    sem_labels, sem_seg = sem2ins_masks_kitti_step(\n                        gt_semantic_seg[i],\n                        
ignore_label=self.ignore_label,\n                        label_shift=2,\n                        thing_label_in_seg=(11, 13))\n                else:\n                    sem_labels, sem_seg = sem2ins_masks(\n                        gt_semantic_seg[i],\n                        ignore_label=self.ignore_label,\n                        label_shift=self.num_thing_classes,\n                        thing_label_in_seg=self.thing_label_in_seg)\n\n                if sem_seg.shape[0] == 0:\n                    gt_sem_seg.append(\n                        mask_tensor.new_zeros(\n                            (mask_tensor.size(0), assign_H, assign_W)))\n                else:\n                    gt_sem_seg.append(\n                        F.interpolate(\n                            sem_seg[None], (assign_H, assign_W),\n                            mode='bilinear',\n                            align_corners=False)[0])\n                gt_sem_cls.append(sem_labels)\n            else:\n                gt_sem_seg = None\n                gt_sem_cls = None\n\n            if mask_tensor.shape[0] == 0:\n                gt_masks_tensor.append(\n                    mask_tensor.new_zeros(\n                        (mask_tensor.size(0), assign_H, assign_W)))\n            else:\n                gt_masks_tensor.append(\n                    F.interpolate(\n                        mask_tensor[None], (assign_H, assign_W),  # downsample to 1/4 resolution\n                        mode='bilinear',\n                        align_corners=False)[0])\n\n        return gt_masks_tensor, gt_sem_cls, gt_sem_seg\n\n    def forward_train(self,\n                      img,\n                      img_metas,\n                      gt_bboxes=None,\n                      gt_labels=None,\n                      gt_bboxes_ignore=None,\n                      gt_masks=None,\n                      gt_semantic_seg=None,\n                      gt_instance_ids=None,\n                      ref_img=None,\n                      
ref_img_metas=None,\n                      ref_gt_bboxes_ignore=None,\n                      ref_gt_labels=None,\n                      ref_gt_bboxes=None,\n                      ref_gt_masks=None,\n                      ref_gt_semantic_seg=None,\n                      ref_gt_instance_ids=None,\n                      proposals=None,\n                      **kwargs):\n        \"\"\"Forward function of SparseR-CNN-like network in train stage.\n\n        Args:\n            img (Tensor): of shape (N, C, H, W) encoding input images.\n                Typically these should be mean centered and std scaled.\n            img_metas (list[dict]): list of image info dict where each dict\n                has: 'img_shape', 'scale_factor', 'flip', and may also contain\n                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.\n                For details on the values of these keys see\n                :class:`mmdet.datasets.pipelines.Collect`.\n            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with\n                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.\n            gt_labels (list[Tensor]): class indices corresponding to each box\n            gt_bboxes_ignore (None | list[Tensor]): specify which bounding\n                boxes can be ignored when computing the loss.\n            gt_masks (list): Ground truth instance masks for each image.\n                Required by this architecture.\n            proposals (List[Tensor], optional): override rpn proposals with\n                custom proposals. Not supported by KNet; must be None.\n\n            # This is for video only:\n            ref_img (Tensor): of shape (N, 1, C, H, W) encoding the reference\n                images, one per key image; dim 1 is squeezed before use.\n                Typically these should be mean centered and std scaled.\n\n            ref_img_metas (list[list[dict]]): The first list only has one\n                element. 
The second list contains reference image information\n                dict where each dict has: 'img_shape', 'scale_factor', 'flip',\n                and may also contain 'filename', 'ori_shape', 'pad_shape', and\n                'img_norm_cfg'. For details on the values of these keys see\n                `mmtrack/datasets/pipelines/formatting.py:VideoCollect`.\n\n            ref_gt_bboxes (list[Tensor]): The list only has one Tensor. The\n                Tensor contains ground truth bboxes for each reference image\n                with shape (num_all_ref_gts, 5) in\n                [ref_img_id, tl_x, tl_y, br_x, br_y] format. The ref_img_id\n                start from 0, and denotes the id of reference image for each\n                key image.\n\n            ref_gt_labels (list[Tensor]): The list only has one Tensor. The\n                Tensor contains class indices corresponding to each reference\n                box with shape (num_all_ref_gts, 2) in\n                [ref_img_id, class_indice].\n\n        Returns:\n            dict[str, Tensor]: a dictionary of loss components\n        \"\"\"\n        batch_input_shape = tuple(img[0].size()[-2:])\n        for img_meta in img_metas:\n            img_meta['batch_input_shape'] = batch_input_shape\n\n        assert proposals is None, 'KNet does not support' \\\n                                  ' external proposals'\n        assert gt_masks is not None\n        assert gt_instance_ids is not None\n\n        # preprocess the reference images\n        ref_img = ref_img.squeeze(1)  # (b,3,h,w)\n        img_h, img_w = batch_input_shape\n        ref_masks_gt = []\n        for ref_gt_mask in ref_gt_masks:\n            ref_masks_gt.append(ref_gt_mask[0])\n\n        ref_labels_gt = []\n        for ref_gt_label in ref_gt_labels:\n            ref_labels_gt.append(ref_gt_label[:, 1].long())\n        ref_gt_labels = ref_labels_gt\n\n        ref_semantic_seg_gt = ref_gt_semantic_seg.squeeze(1)\n\n        ref_gt_instance_id_list 
= []\n        for ref_gt_instance_id in ref_gt_instance_ids:\n            ref_gt_instance_id_list.append(ref_gt_instance_id[:,1].long())\n\n        ref_img_metas_new = []\n        for ref_img_meta in ref_img_metas:\n            ref_img_meta[0]['batch_input_shape'] = batch_input_shape\n            ref_img_metas_new.append(ref_img_meta[0])\n\n        # prepare the gt_match_indices\n        gt_pids_list = []\n        for i in range(len(ref_gt_instance_id_list)):\n            ref_ids = ref_gt_instance_id_list[i].cpu().data.numpy().tolist()\n            gt_ids = gt_instance_ids[i].cpu().data.numpy().tolist()\n            gt_pids = [ref_ids.index(i) if i in ref_ids else -1 for i in gt_ids]\n            gt_pids_list.append(torch.LongTensor([gt_pids]).to(img.device)[0])\n\n        gt_match_indices = gt_pids_list\n\n        # gt_masks and gt_semantic_seg are not padded when forming batch\n        gt_masks, gt_sem_cls, gt_sem_seg = self.preprocess_gt_masks(img_metas, gt_masks, gt_labels, gt_semantic_seg)\n        ref_gt_masks, ref_gt_sem_cls, ref_gt_sem_seg = self.preprocess_gt_masks(ref_img_metas_new,\n                                                                    ref_masks_gt, ref_gt_labels, ref_semantic_seg_gt)\n\n        x = self.extract_feat(img)\n        x_ref = self.extract_feat(ref_img)\n\n        rpn_results = self.rpn_head.forward_train(x, img_metas, gt_masks,\n                                                  gt_labels, gt_sem_seg,\n                                                  gt_sem_cls)\n\n        ref_rpn_results = self.rpn_head.forward_train(x_ref, ref_img_metas_new, ref_gt_masks,\n                                                      ref_labels_gt, ref_gt_sem_seg,\n                                                      ref_gt_sem_cls)\n\n        (rpn_losses, proposal_feats, x_feats, mask_preds,\n         cls_scores) = rpn_results\n\n        (ref_rpn_losses, ref_proposal_feats, ref_x_feats, ref_mask_preds,\n         ref_cls_scores) = 
ref_rpn_results\n\n        losses_ref, ref_obj_feats, ref_cls_scores, ref_mask_preds, ref_scaled_mask_preds = self.roi_head.forward_train(\n            ref_x_feats,\n            ref_proposal_feats,\n            ref_mask_preds,\n            ref_cls_scores,\n            ref_img_metas,\n            ref_gt_masks,\n            ref_gt_labels,\n            gt_sem_seg=ref_gt_sem_seg,\n            gt_sem_cls=ref_gt_sem_cls,\n            imgs_whwh=None)\n\n        if self.link_previous:\n            losses, object_feats, cls_scores, mask_preds, scaled_mask_preds, object_feats_track = self.roi_head.forward_train_with_previous(\n                x_feats,\n                proposal_feats,\n                mask_preds,\n                cls_scores,\n                img_metas,\n                gt_masks,\n                gt_labels,\n                gt_bboxes_ignore=gt_bboxes_ignore,\n                gt_bboxes=gt_bboxes,\n                gt_sem_seg=gt_sem_seg,\n                gt_sem_cls=gt_sem_cls,\n                imgs_whwh=None,\n                previous_obj_feats=ref_obj_feats,\n                previous_mask_preds=ref_scaled_mask_preds,\n                previous_x_feats=ref_x_feats,\n            )\n        else:\n            # forward to get the current results\n            losses, object_feats, cls_scores, mask_preds, scaled_mask_preds = self.roi_head.forward_train(\n                x_feats,\n                proposal_feats,\n                mask_preds,\n                cls_scores,\n                img_metas,\n                gt_masks,\n                gt_labels,\n                gt_bboxes_ignore=gt_bboxes_ignore,\n                gt_bboxes=gt_bboxes,\n                gt_sem_seg=gt_sem_seg,\n                gt_sem_cls=gt_sem_cls,\n                imgs_whwh=None)\n\n        # ===== Tracking Part -==== #\n        # assign both key frame and reference frame tracking targets\n        key_sampling_results, ref_sampling_results = [], []\n        num_imgs = len(img_metas)\n\n        for i 
in range(num_imgs):\n            assign_result = self.track_roi_assigner.assign(\n                scaled_mask_preds[i][:self.num_proposals].detach(), cls_scores[i][:self.num_proposals, :self.num_thing_classes].detach(),\n                gt_masks[i], gt_labels[i], img_meta=img_metas[i])\n            sampling_result = self.track_roi_sampler.sample(\n                assign_result,\n                mask_preds[i][:self.num_proposals].detach(),\n                gt_masks[i])\n            key_sampling_results.append(sampling_result)\n\n            ref_assign_result = self.track_roi_assigner.assign(\n                ref_scaled_mask_preds[i][:self.num_proposals].detach(), ref_cls_scores[i][:self.num_proposals, :self.num_thing_classes].detach(),\n                ref_gt_masks[i], ref_gt_labels[i], img_meta=ref_img_metas_new[i])\n            ref_sampling_result = self.track_roi_sampler.sample(\n                ref_assign_result,\n                ref_mask_preds[i][:self.num_proposals].detach(),\n                ref_gt_masks[i])\n            ref_sampling_results.append(ref_sampling_result)\n\n        # embed the linked tracking queries when link_previous is set,\n        # otherwise fall back to the current-frame object queries\n        if not self.link_previous:\n            object_feats_track = object_feats\n        N, num_proposal, _, _, _ = object_feats_track.shape\n        emb_feat = object_feats_track.squeeze(-2).squeeze(-1)[:, :self.num_proposals, ]\n\n        for emb_layer in self.embed_fcs:\n            emb_feat = emb_layer(emb_feat)\n        object_feats_embed = self.fc_embed(emb_feat).view(N, self.num_proposals, -1)\n\n        ref_emb_feat = ref_obj_feats.squeeze(-2).squeeze(-1)[:, :self.num_proposals, ]\n        for emb_layer in self.embed_fcs:\n            ref_emb_feat = emb_layer(ref_emb_feat)\n        ref_object_feats_embed = self.fc_embed(ref_emb_feat).view(N, self.num_proposals, -1)\n\n        # sampling predicted GT mask\n        key_emb_indexs = [res.pos_inds for res in key_sampling_results]\n        object_feats_embed_list = []\n        for i in range(len(key_emb_indexs)):\n            object_feats_embed_list.append(object_feats_embed[:, 
key_emb_indexs[i], :].squeeze(0))\n\n        key_feats = self._track_forward(object_feats_embed_list)\n\n        ref_emb_indexs = [res.pos_inds for res in ref_sampling_results]\n        ref_object_feats_embed_list = []\n        for i in range(len(ref_emb_indexs)):\n            ref_object_feats_embed_list.append(ref_object_feats_embed[:, ref_emb_indexs[i], :].squeeze(0))\n\n        ref_feats = self._track_forward(ref_object_feats_embed_list)\n\n        match_feats = self.track_head.match(key_feats, ref_feats,\n                                            key_sampling_results,\n                                            ref_sampling_results)\n\n        asso_targets = self.track_head.get_track_targets(\n            gt_match_indices, key_sampling_results, ref_sampling_results)\n        loss_track = self.track_head.loss(*match_feats, *asso_targets)\n\n        ref_losses = self.add_ref_loss(losses_ref)\n        ref_rpn_losses = self.add_ref_rpn_loss(ref_rpn_losses)\n\n        losses.update(ref_rpn_losses)\n        losses.update(rpn_losses)\n        losses.update(ref_losses)\n        losses.update(loss_track)\n\n        return losses\n\n    def simple_test(self, img, img_metas, rescale=False, ref_img=None, **kwargs):\n        \"\"\"Test function without test time augmentation.\n\n        Args:\n            imgs (list[torch.Tensor]): List of multiple images\n            img_metas (list[dict]): List of image information.\n            rescale (bool): Whether to rescale the results.\n                Defaults to False.\n\n        Returns:\n            list[list[np.ndarray]]: BBox results of each image and classes.\n                The outer list corresponds to each image. 
The inner list\n                corresponds to each class.\n        \"\"\"\n\n        # set the dataset type\n        if self.cityscapes and not self.kitti_step and not self.cityscapes_short and not self.vipseg:\n            iid = img_metas[0]['iid']\n            fid = iid % 10000\n            is_first = (fid == 1)\n        else:\n            iid = kwargs['img_id'][0].item()\n            fid = iid % 10000\n            is_first = (fid == 0)\n\n        # for current frame\n        x = self.extract_feat(img)\n        # current frame inference\n        rpn_results = self.rpn_head.simple_test_rpn(x, img_metas)\n        (proposal_feats, x_feats, mask_preds, cls_scores,\n         seg_preds) = rpn_results\n\n        # init tracker\n        if is_first:\n            self.init_tracker()\n            self.obj_feats_memory = None\n            self.x_feats_memory = None\n            self.mask_preds_memory = None\n            print(\"fid\", fid)\n\n        # whether to link with the previous frame\n        if self.link_previous:\n            cur_segm_results, obj_feats, cls_scores, mask_preds, scaled_mask_preds = self.roi_head.simple_test_with_previous(\n                x_feats,\n                proposal_feats,\n                mask_preds,\n                cls_scores,\n                img_metas,\n                previous_obj_feats=self.obj_feats_memory,\n                previous_mask_preds=self.mask_preds_memory,\n                previous_x_feats=self.x_feats_memory,\n                is_first=is_first\n            )\n            self.obj_feats_memory = obj_feats\n            self.x_feats_memory = x_feats\n            self.mask_preds_memory = scaled_mask_preds\n        else:\n            cur_segm_results, query_output, cls_scores, mask_preds, scaled_mask_preds = self.roi_head.simple_test(\n                x_feats,\n                proposal_feats,\n                mask_preds,\n                cls_scores,\n                img_metas)\n\n        # for tracking part\n        _, segm_result, 
mask_preds, panoptic_result, query_output = cur_segm_results[0]\n        panoptic_seg, segments_info = panoptic_result\n\n        # get sorted tracking thing ids, labels, masks, score for tracking\n        things_index_for_tracking, things_labels_for_tracking, thing_masks_for_tracking, things_score_for_tracking = \\\n            self.get_things_id_for_tracking(panoptic_seg, segments_info)\n        things_labels_for_tracking = torch.Tensor(things_labels_for_tracking).to(cls_scores.device).long()\n\n        # get the semantic filter\n        if self.semantic_filter:\n            seg_preds = torch.nn.functional.interpolate(seg_preds, panoptic_seg.shape, mode='bilinear',\n                                                        align_corners=False)\n            seg_preds = seg_preds.sigmoid()\n            seg_out = seg_preds.argmax(1)\n            semantic_thing = (seg_out < self.num_thing_classes).to(dtype=torch.float32)\n        else:\n            semantic_thing = 1.\n\n        if len(things_labels_for_tracking) > 0:\n            things_bbox_for_tracking = torch.zeros((len(things_score_for_tracking), 5),\n                                                   dtype=torch.float, device=x_feats.device)\n            things_bbox_for_tracking[:, 4] = torch.tensor(things_score_for_tracking,\n                                                          device=things_bbox_for_tracking.device)\n\n            thing_masks_for_tracking_final = []\n            for mask in thing_masks_for_tracking:\n                thing_masks_for_tracking_final.append(torch.Tensor(mask).unsqueeze(0).to(\n                    x_feats.device).float())\n            thing_masks_for_tracking_final = torch.cat(thing_masks_for_tracking_final, 0)\n            thing_masks_for_tracking = thing_masks_for_tracking_final\n            thing_masks_for_tracking_with_semantic_filter = thing_masks_for_tracking_final * semantic_thing\n\n        if len(things_labels_for_tracking) == 0:\n            track_feats = None\n       
 else:\n            # tracking embeddings\n            N, _, _, _ = query_output.shape\n            emb_feat = query_output.squeeze(-2).squeeze(-1).unsqueeze(0)  # (n,d,1,1) -> (1,n,d)\n\n            for emb_layer in self.embed_fcs:\n                emb_feat = emb_layer(emb_feat)\n            object_feats_embed = self.fc_embed(emb_feat).view(1, N, -1)\n            object_feats_embed_for_tracking = object_feats_embed.squeeze(0)\n            track_feats = self._track_forward([object_feats_embed_for_tracking])\n\n        if track_feats is not None:\n            things_bbox_for_tracking[:, :4] = torch.tensor(tensor_mask2box(thing_masks_for_tracking_with_semantic_filter),\n                                                           device=things_bbox_for_tracking.device)\n            bboxes, labels, ids = self.tracker.match(\n                bboxes=things_bbox_for_tracking,\n                labels=things_labels_for_tracking,\n                track_feats=track_feats,\n                frame_id=fid)\n\n            ids = ids + 1\n            ids[ids == -1] = 0\n\n            # print(\"track feats:\", track_feats[0])\n            # print(\"id\", ids)\n\n        else:\n            ids = []\n\n\n        track_maps = self.generate_track_id_maps(ids, thing_masks_for_tracking, panoptic_seg)\n\n        semantic_map = self.get_semantic_seg(panoptic_seg, segments_info)\n\n        from scripts.visualizer import trackmap2rgb, cityscapes_cat2rgb, draw_bbox_on_img\n        vis_tracker = trackmap2rgb(track_maps)\n        vis_sem = cityscapes_cat2rgb(semantic_map)\n        if len(things_labels_for_tracking):\n            vis_tracker = draw_bbox_on_img(vis_tracker, things_bbox_for_tracking.cpu().numpy())\n\n        # Visualization usage\n        return semantic_map, track_maps, None, vis_sem, vis_tracker\n\n    def _track_forward(self, track_feats, x=None, mask_pred=None):\n        \"\"\"Track head forward function used in both training and testing.\n        We use mask pooling to get the 
fine grain features\"\"\"\n        # if not self.training:\n        #     mask_pred = [mask_pred]\n        track_feats = torch.cat(track_feats, 0)\n\n        track_feats = self.track_head(track_feats)\n\n        return track_feats\n\n    def forward_dummy(self, img, img_metas=None):\n        \"\"\"Used for computing network flops.\n\n        See `mmdetection/tools/get_flops.py`\n        \"\"\"\n        # backbone\n        x = self.extract_feat(img)\n        # rpn\n        num_imgs = len(img)\n        dummy_img_metas = [\n            dict(img_shape=(0, 0, 3)) for _ in range(num_imgs)\n        ]\n        rpn_results = self.rpn_head.simple_test_rpn(x, dummy_img_metas)\n        (proposal_feats, x_feats, mask_preds, cls_scores,\n         seg_preds) = rpn_results\n        roi_outs = self.roi_head.simple_test_mask_preds(\n            x_feats,\n            proposal_feats,\n            mask_preds,\n            cls_scores,\n            dummy_img_metas)\n        return roi_outs\n\n    def extract_feat(self, img):\n        \"\"\"Directly extract features from the backbone+neck.\"\"\"\n        x = self.backbone(img)\n        if self.with_neck:\n            x = self.neck(x)\n        return x\n\n    @property\n    def with_rpn(self):\n        \"\"\"bool: whether the detector has RPN\"\"\"\n        return hasattr(self, 'rpn_head') and self.rpn_head is not None\n\n    @property\n    def with_roi_head(self):\n        \"\"\"bool: whether the detector has a RoI head\"\"\"\n        return hasattr(self, 'roi_head') and self.roi_head is not None\n\n    def aug_test(self, x, proposal_list, img_metas, rescale=False, **kwargs):\n        \"\"\"Test with augmentations.\n\n        If rescale is False, then returned bboxes and masks will fit the scale\n        of imgs[0].\n        \"\"\"\n        pass\n\n    def get_things_id_for_tracking(self, panoptic_seg, seg_infos):\n        idxs = []\n        labels = []\n        masks = []\n        score = []\n        for segment in seg_infos:\n           
 if segment['isthing'] == True:\n                thing_mask = panoptic_seg == segment[\"id\"]\n                masks.append(thing_mask)\n                idxs.append(segment[\"instance_id\"])\n                labels.append(segment['category_id'])\n                score.append(segment['score'])\n        return idxs, labels, masks, score\n\n    def pack_things_object(self, object_feats, ref_object_feats):\n        object_feats, ref_object_feats = object_feats.squeeze(-1).squeeze(-1), ref_object_feats.squeeze(-1).squeeze(-1)\n        thing_object_feats = torch.split(object_feats, [self.roi_head.num_proposals, self.num_stuff_classes], dim=1)[0]\n        ref_thing_object_feats = torch.split(ref_object_feats, [self.roi_head.num_proposals, self.num_stuff_classes], dim=1)[0]\n        return thing_object_feats, ref_thing_object_feats\n\n    def pack_things_masks(self, mask_pred, ref_mask_pred):\n        thing_mask_pred = torch.split(mask_pred, [self.roi_head.num_proposals, self.num_stuff_classes], dim=1)[0]\n        ref_thing_thing_mask_pred = torch.split(ref_mask_pred, [self.roi_head.num_proposals, self.num_stuff_classes], dim=1)[0]\n        return thing_mask_pred, ref_thing_thing_mask_pred\n\n    def get_semantic_seg(self, panoptic_seg, segments_info):\n        kitti_step2cityscpaes = [11, 13]\n        semantic_seg = np.zeros(panoptic_seg.shape)\n        for segment in segments_info:\n            if segment['isthing'] == True:\n                # for things\n                if self.kitti_step:\n                    cat_cur = kitti_step2cityscpaes[segment[\"category_id\"]]\n                    semantic_seg[panoptic_seg == segment[\"id\"]] = cat_cur\n                else:   # city and vip_seg\n                    semantic_seg[panoptic_seg == segment[\"id\"]] = segment[\"category_id\"] + self.num_stuff_classes\n            else:\n                # for stuff (0 - n-1)\n                if self.kitti_step:\n                    cat_cur = segment[\"category_id\"]\n                   
 cat_cur -= 1\n                    offset = 0\n                    for thing_id in kitti_step2cityscpaes:\n                        if cat_cur + offset >= thing_id:\n                            offset += 1\n                    cat_cur += offset\n                    semantic_seg[panoptic_seg == segment[\"id\"]] = cat_cur\n                else:   # city and vip_seg\n                    semantic_seg[panoptic_seg == segment[\"id\"]] = segment[\"category_id\"] - 1\n        return semantic_seg\n\n    def generate_track_id_maps(self, ids, masks, panopitc_seg_maps):\n\n        final_id_maps = np.zeros(panopitc_seg_maps.shape)\n\n        if len(ids) == 0:\n            return final_id_maps\n        masks = masks.bool()\n\n        for i, id in enumerate(ids):\n            mask = masks[i].cpu().numpy()\n            final_id_maps[mask] = id\n\n        return final_id_maps\n\n    def add_ref_loss(self, loss_dict):\n        track_loss ={}\n        for k, v in loss_dict.items():\n            track_loss[str(k)+\"_ref\"] = v\n        return track_loss\n\n    def add_ref_rpn_loss(self, loss_dict):\n        ref_rpn_loss = {}\n        for k, v in loss_dict.items():\n            ref_rpn_loss[str(k) +\"_ref_rpn\"] = v\n        return ref_rpn_loss"
  },
  {
    "path": "knet/video/knet_quansi_dense_embed_fc_toy_exp.py",
    "content": "import warnings\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom mmdet.models.builder import DETECTORS\nfrom mmdet.models.detectors import BaseDetector\nfrom mmdet.models.builder import build_head, build_neck, build_backbone, build_roi_extractor\nfrom mmdet.core import build_assigner, build_sampler\nfrom knet.video.qdtrack.builder import build_tracker\nfrom knet.det.utils import sem2ins_masks, sem2ins_masks_cityscapes, sem2ins_masks_kitti_step\nfrom unitrack.mask import tensor_mask2box\n\n\n@DETECTORS.register_module()\nclass VideoKNetQuansiEmbedFCToy(BaseDetector):\n    \"\"\"\n        Simple Extension of KNet to Video KNet by directly propagation the kernels.\n    \"\"\"\n    def __init__(self,\n                 backbone,\n                 neck=None,\n                 rpn_head=None,\n                 roi_head=None,\n                 track_head=None,\n                 extra_neck=None,\n                 track_localization_fpn=None,\n                 track_mhsa=False,\n                 tracker=None,\n                 train_cfg=None,\n                 test_cfg=None,\n                 track_train_cfg=None,\n                 pretrained=None,\n                 init_cfg=None,\n                 num_thing_classes=80,\n                 num_stuff_classes=53,\n                 mask_assign_stride=4,\n                 ignore_label=255,\n                 thing_label_in_seg=0,\n                 detach_mask_emd=False,\n                 cityscapes=False,\n                 kitti_step=False,\n                 freeze_detector=False,\n                 semantic_filter=True,\n                 link_previous=False,\n                 bbox_roi_extractor=dict(\n                     type='SingleRoIExtractor',\n                     roi_layer=dict(\n                         type='RoIAlign', output_size=7, sampling_ratio=2),\n                     out_channels=256,\n                     featmap_strides=[4, 8, 16, 32]),\n                 **kwargs):\n        
super(VideoKNetQuansiEmbedFCToy, self).__init__(init_cfg)\n\n        if pretrained:\n            warnings.warn('DeprecationWarning: pretrained is deprecated, '\n                          'please use \"init_cfg\" instead')\n            backbone.pretrained = pretrained\n        self.backbone = build_backbone(backbone)\n\n        if neck is not None:\n            self.neck = build_neck(neck)\n\n        if extra_neck is not None:\n            self.extra_neck = build_neck(extra_neck)\n\n        if rpn_head is not None:\n            rpn_train_cfg = train_cfg.rpn if train_cfg is not None else None\n            rpn_head_ = rpn_head.copy()\n            rpn_head_.update(train_cfg=rpn_train_cfg, test_cfg=test_cfg.rpn)\n            self.rpn_head = build_head(rpn_head_)\n\n        if roi_head is not None:\n            # update train and test cfg here for now\n            # TODO: refactor assigner & sampler\n            rcnn_train_cfg = train_cfg.rcnn if train_cfg is not None else None\n            roi_head.update(train_cfg=rcnn_train_cfg)\n            roi_head.update(test_cfg=test_cfg.rcnn)\n            roi_head.pretrained = pretrained\n            self.roi_head = build_head(roi_head)\n\n        if track_head is not None:\n            self.track_train_cfg = track_train_cfg\n            self.track_head = build_head(track_head)\n            self.init_track_assigner_sampler()\n            if track_localization_fpn is not None:\n                self.track_localization_fpn = build_neck(track_localization_fpn)\n            if bbox_roi_extractor is not None:\n                self.track_roi_extractor = build_roi_extractor(\n                    bbox_roi_extractor)\n\n        if tracker is not None:\n            self.tracker_cfg = tracker\n\n        if freeze_detector:\n           self._freeze_detector()\n\n        self.train_cfg = train_cfg\n        self.test_cfg = test_cfg\n        self.num_proposals = self.rpn_head.num_proposals\n        self.num_thing_classes = num_thing_classes\n    
    self.num_stuff_classes = num_stuff_classes\n        self.mask_assign_stride = mask_assign_stride\n        self.thing_label_in_seg = thing_label_in_seg\n        self.ignore_label = ignore_label\n        self.cityscapes = cityscapes  # whether to train the cityscape panoptic segmentation\n        self.kitti_step = kitti_step  # whether to train the kitti step panoptic segmentation\n\n        self.semantic_filter = semantic_filter\n        self.link_previous = link_previous\n        self.detach_mask_emd = detach_mask_emd\n        self.track_mhsa = track_mhsa\n        # add embedding fcs for the final stage queries\n        # num_emb_fcs = 1\n        # act_cfg = dict(type='ReLU', inplace=True)\n        # in_channels = 256\n        # out_channels = 256\n        # self.embed_fcs = nn.ModuleList()\n        # for _ in range(num_emb_fcs):\n        #     self.embed_fcs.append(\n        #         nn.Linear(in_channels, in_channels, bias=False))\n        #     self.embed_fcs.append(\n        #         build_norm_layer(dict(type='LN'), in_channels)[1])\n        #     self.embed_fcs.append(build_activation_layer(act_cfg))\n        #\n        # self.fc_embed = nn.Linear(in_channels, out_channels)\n\n    def init_tracker(self):\n        self.tracker = build_tracker(self.tracker_cfg)\n\n    def _freeze_detector(self):\n\n        self.detector = [\n            self.rpn_head, self.roi_head\n        ]\n        for model in self.detector:\n            model.eval()\n            for param in model.parameters():\n                param.requires_grad = False\n\n    def init_track_assigner_sampler(self):\n        \"\"\"Initialize assigner and sampler.\"\"\"\n\n        self.track_roi_assigner = build_assigner(\n            self.track_train_cfg.assigner)\n        self.track_share_assigner = False\n\n        self.track_roi_sampler = build_sampler(\n            self.track_train_cfg.sampler, context=self)\n        self.track_share_sampler = False\n\n    def preprocess_gt_masks(self, 
img_metas, gt_masks, gt_labels, gt_semantic_seg):\n        # gt_masks and gt_semantic_seg are not padded when forming a batch\n        gt_masks_tensor = []\n        gt_sem_seg = []\n        gt_sem_cls = []\n        # batch_input_shape should be the same across images\n        pad_H, pad_W = img_metas[0]['batch_input_shape']\n        assign_H = pad_H // self.mask_assign_stride\n        assign_W = pad_W // self.mask_assign_stride\n\n        for i, gt_mask in enumerate(gt_masks):\n            mask_tensor = gt_mask.to_tensor(torch.float, gt_labels[0].device)\n            if gt_mask.width != pad_W or gt_mask.height != pad_H:\n                pad_wh = (0, pad_W - gt_mask.width, 0, pad_H - gt_mask.height)\n                mask_tensor = F.pad(mask_tensor, pad_wh, value=0)\n\n            if gt_semantic_seg is not None:\n                # gt_semantic_seg is zero-padded when forming a batch,\n                # so the padded region must be set to ignore_label\n                gt_semantic_seg[\n                i, :, img_metas[i]['img_shape'][0]:, :] = self.ignore_label\n                gt_semantic_seg[\n                i, :, :, img_metas[i]['img_shape'][1]:] = self.ignore_label\n                if self.cityscapes:\n                    sem_labels, sem_seg = sem2ins_masks_cityscapes(\n                        gt_semantic_seg[i],\n                        ignore_label=self.ignore_label,\n                        label_shift=self.num_thing_classes,\n                        thing_label_in_seg=list(range(self.num_stuff_classes,\n                                                      self.num_thing_classes + self.num_stuff_classes))\n                    )\n                elif self.kitti_step:\n                    sem_labels, sem_seg = sem2ins_masks_kitti_step(\n                        gt_semantic_seg[i],\n                        ignore_label=self.ignore_label,\n                        label_shift=2,\n                        thing_label_in_seg=(11, 13))\n                else:\n                    
sem_labels, sem_seg = sem2ins_masks(\n                        gt_semantic_seg[i],\n                        ignore_label=self.ignore_label,\n                        label_shift=self.num_thing_classes,\n                        thing_label_in_seg=self.thing_label_in_seg)\n\n                if sem_seg.shape[0] == 0:\n                    gt_sem_seg.append(\n                        mask_tensor.new_zeros(\n                            (mask_tensor.size(0), assign_H, assign_W)))\n                else:\n                    gt_sem_seg.append(\n                        F.interpolate(\n                            sem_seg[None], (assign_H, assign_W),\n                            mode='bilinear',\n                            align_corners=False)[0])\n                gt_sem_cls.append(sem_labels)\n            else:\n                gt_sem_seg = None\n                gt_sem_cls = None\n\n            if mask_tensor.shape[0] == 0:\n                gt_masks_tensor.append(\n                    mask_tensor.new_zeros(\n                        (mask_tensor.size(0), assign_H, assign_W)))\n            else:\n                gt_masks_tensor.append(\n                    F.interpolate(\n                        mask_tensor[None], (assign_H, assign_W),  # downsample to 1/4 resolution\n                        mode='bilinear',\n                        align_corners=False)[0])\n\n        return gt_masks_tensor, gt_sem_cls, gt_sem_seg\n\n    def forward_train(self,\n                      img,\n                      img_metas,\n                      gt_bboxes=None,\n                      gt_labels=None,\n                      gt_bboxes_ignore=None,\n                      gt_masks=None,\n                      gt_semantic_seg=None,\n                      gt_instance_ids=None,\n                      ref_img=None,\n                      ref_img_metas=None,\n                      ref_gt_bboxes_ignore=None,\n                      ref_gt_labels=None,\n                      ref_gt_bboxes=None,\n               
       ref_gt_masks=None,\n                      ref_gt_semantic_seg=None,\n                      ref_gt_instance_ids=None,\n                      proposals=None,\n                      **kwargs):\n        \"\"\"Forward function of the SparseR-CNN-like network in the training stage.\n\n        Args:\n            img (Tensor): of shape (N, C, H, W) encoding input images.\n                Typically these should be mean centered and std scaled.\n            img_metas (list[dict]): list of image info dict where each dict\n                has: 'img_shape', 'scale_factor', 'flip', and may also contain\n                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.\n                For details on the values of these keys see\n                :class:`mmdet.datasets.pipelines.Collect`.\n            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with\n                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.\n            gt_labels (list[Tensor]): class indices corresponding to each box.\n            gt_bboxes_ignore (None | list[Tensor]): specify which bounding\n                boxes can be ignored when computing the loss.\n            gt_masks (List[Tensor]): Segmentation masks for each box.\n                Required by this architecture.\n            proposals (List[Tensor], optional): override rpn proposals with\n                custom proposals. Use when `with_rpn` is False.\n\n            # This is for video only:\n            ref_img (Tensor): of shape (N, 2, C, H, W) encoding input images.\n                Typically these should be mean centered and std scaled.\n                2 denotes that there are two reference images for each input image.\n\n            ref_img_metas (list[list[dict]]): The first list only has one\n                element. 
The second list contains reference image information\n                dict where each dict has: 'img_shape', 'scale_factor', 'flip',\n                and may also contain 'filename', 'ori_shape', 'pad_shape', and\n                'img_norm_cfg'. For details on the values of these keys see\n                `mmtrack/datasets/pipelines/formatting.py:VideoCollect`.\n\n            ref_gt_bboxes (list[Tensor]): The list only has one Tensor. The\n                Tensor contains ground truth bboxes for each reference image\n                with shape (num_all_ref_gts, 5) in\n                [ref_img_id, tl_x, tl_y, br_x, br_y] format. The ref_img_id\n                start from 0, and denotes the id of reference image for each\n                key image.\n\n            ref_gt_labels (list[Tensor]): The list only has one Tensor. The\n                Tensor contains class indices corresponding to each reference\n                box with shape (num_all_ref_gts, 2) in\n                [ref_img_id, class_indice].\n\n        Returns:\n            dict[str, Tensor]: a dictionary of loss components\n        \"\"\"\n        batch_input_shape = tuple(img[0].size()[-2:])\n        for img_meta in img_metas:\n            img_meta['batch_input_shape'] = batch_input_shape\n\n        assert proposals is None, 'KNet does not support' \\\n                                  ' external proposals'\n        assert gt_masks is not None\n        assert gt_instance_ids is not None\n\n        # preprocess the reference images\n        ref_img = ref_img.squeeze(1)  # (b,3,h,w)\n        img_h, img_w = batch_input_shape\n        ref_masks_gt = []\n        for ref_gt_mask in ref_gt_masks:\n            ref_masks_gt.append(ref_gt_mask[0])\n\n        ref_labels_gt = []\n        for ref_gt_label in ref_gt_labels:\n            ref_labels_gt.append(ref_gt_label[:, 1].long())\n        ref_gt_labels = ref_labels_gt\n\n        ref_semantic_seg_gt = ref_gt_semantic_seg.squeeze(1)\n\n        ref_gt_instance_id_list 
= []\n        for ref_gt_instance_id in ref_gt_instance_ids:\n            ref_gt_instance_id_list.append(ref_gt_instance_id[:,1].long())\n\n        ref_img_metas_new = []\n        for ref_img_meta in ref_img_metas:\n            ref_img_meta[0]['batch_input_shape'] = batch_input_shape\n            ref_img_metas_new.append(ref_img_meta[0])\n\n        # prepare the gt_match_indices\n        gt_pids_list = []\n        for i in range(len(ref_gt_instance_id_list)):\n            ref_ids = ref_gt_instance_id_list[i].cpu().data.numpy().tolist()\n            gt_ids = gt_instance_ids[i].cpu().data.numpy().tolist()\n            gt_pids = [ref_ids.index(i) if i in ref_ids else -1 for i in gt_ids]\n            gt_pids_list.append(torch.LongTensor([gt_pids]).to(img.device)[0])\n\n        gt_match_indices = gt_pids_list\n\n        # gt_masks and gt_semantic_seg are not padded when forming batch\n        gt_masks, gt_sem_cls, gt_sem_seg = self.preprocess_gt_masks(img_metas, gt_masks, gt_labels, gt_semantic_seg)\n        ref_gt_masks, ref_gt_sem_cls, ref_gt_sem_seg = self.preprocess_gt_masks(ref_img_metas_new,\n                                                                    ref_masks_gt, ref_gt_labels, ref_semantic_seg_gt)\n\n        x = self.extract_feat(img)\n        x_ref = self.extract_feat(ref_img)\n\n        rpn_results = self.rpn_head.forward_train(x, img_metas, gt_masks,\n                                                  gt_labels, gt_sem_seg,\n                                                  gt_sem_cls)\n\n        # simple forward to get the reference results\n        self.rpn_head.eval()\n        ref_rpn_results = self.rpn_head.simple_test_rpn(x_ref, ref_img_metas_new)\n        self.rpn_head.train()\n\n        (rpn_losses, proposal_feats, x_feats, mask_preds,\n         cls_scores) = rpn_results\n\n        (ref_proposal_feats, ref_x_feats, ref_mask_preds,\n         ref_cls_scores, ref_seg_preds) = ref_rpn_results\n\n        ref_obj_feats,  ref_cls_scores, 
ref_mask_preds, ref_scaled_mask_preds = self.roi_head.simple_test_mask_preds(\n            ref_x_feats,\n            ref_proposal_feats,\n            ref_mask_preds,\n            ref_cls_scores,\n            ref_img_metas_new,\n           )\n\n        if self.link_previous:\n            losses, object_feats, cls_scores, mask_preds, scaled_mask_preds, object_feats_track = self.roi_head.forward_train_with_previous(\n                x_feats,\n                proposal_feats,\n                mask_preds,\n                cls_scores,\n                img_metas,\n                gt_masks,\n                gt_labels,\n                gt_bboxes_ignore=gt_bboxes_ignore,\n                gt_bboxes=gt_bboxes,\n                gt_sem_seg=gt_sem_seg,\n                gt_sem_cls=gt_sem_cls,\n                imgs_whwh=None,\n                previous_obj_feats=ref_obj_feats,\n                previous_mask_preds=ref_scaled_mask_preds,\n                previous_x_feats=ref_x_feats,\n            )\n        else:\n            # forward to get the current results\n            losses, object_feats, cls_scores, mask_preds, scaled_mask_preds = self.roi_head.forward_train(\n                x_feats,\n                proposal_feats,\n                mask_preds,\n                cls_scores,\n                img_metas,\n                gt_masks,\n                gt_labels,\n                gt_bboxes_ignore=gt_bboxes_ignore,\n                gt_bboxes=gt_bboxes,\n                gt_sem_seg=gt_sem_seg,\n                gt_sem_cls=gt_sem_cls,\n                imgs_whwh=None)\n\n        # ===== Tracking Part -==== #\n        # assign both key frame and reference frame tracking targets\n        key_sampling_results, ref_sampling_results = [], []\n        num_imgs = len(img_metas)\n\n        for i in range(num_imgs):\n            assign_result = self.track_roi_assigner.assign(\n                scaled_mask_preds[i][:self.num_proposals].detach(), cls_scores[i][:self.num_proposals, 
:self.num_thing_classes].detach(),\n                gt_masks[i], gt_labels[i], img_meta=img_metas[i])\n            sampling_result = self.track_roi_sampler.sample(\n                assign_result,\n                mask_preds[i][:self.num_proposals].detach(),\n                gt_masks[i])\n            key_sampling_results.append(sampling_result)\n\n            ref_assign_result = self.track_roi_assigner.assign(\n                ref_scaled_mask_preds[i][:self.num_proposals].detach(), ref_cls_scores[i][:self.num_proposals, :self.num_thing_classes].detach(),\n                ref_gt_masks[i], ref_gt_labels[i], img_meta=ref_img_metas_new[i])\n            ref_sampling_result = self.track_roi_sampler.sample(\n                ref_assign_result,\n                ref_mask_preds[i][:self.num_proposals].detach(),\n                ref_gt_masks[i])\n            ref_sampling_results.append(ref_sampling_result)\n        if self.detach_mask_emd:\n            object_feats = object_feats.detach()\n            ref_obj_feats = ref_obj_feats.detach()\n\n        if self.link_previous:\n            object_feats = object_feats_track\n\n        N, num_proposal, _, _, _ = object_feats.shape\n        emb_feat = object_feats.squeeze(-2).squeeze(-1)[:, :self.num_proposals, ]\n\n        for emb_layer in self.embed_fcs:\n            emb_feat = emb_layer(emb_feat)\n        object_feats_embed = self.fc_embed(emb_feat).view(N, self.num_proposals, -1)\n\n\n        ref_emb_feat = ref_obj_feats.squeeze(-2).squeeze(-1)[:, :self.num_proposals, ]\n        for emb_layer in self.embed_fcs:\n            ref_emb_feat = emb_layer(ref_emb_feat)\n        ref_object_feats_embed = self.fc_embed(ref_emb_feat).view(N, self.num_proposals, -1)\n\n        # sampling predicted GT mask\n        key_emb_indexs = [res.pos_inds for res in key_sampling_results]\n        object_feats_embed_list = []\n        for i in range(len(key_emb_indexs)):\n            object_feats_embed_list.append(object_feats_embed[:, key_emb_indexs[i], 
:].squeeze(0))\n\n        key_feats = self._track_forward(object_feats_embed_list)\n\n        ref_emb_indexs = [res.pos_inds for res in ref_sampling_results]\n        ref_object_feats_embed_list = []\n        for i in range(len(ref_emb_indexs)):\n            ref_object_feats_embed_list.append(ref_object_feats_embed[:, ref_emb_indexs[i], :].squeeze(0))\n\n        ref_feats = self._track_forward(ref_object_feats_embed_list)\n\n        match_feats = self.track_head.match(key_feats, ref_feats,\n                                            key_sampling_results,\n                                            ref_sampling_results)\n\n        asso_targets = self.track_head.get_track_targets(\n            gt_match_indices, key_sampling_results, ref_sampling_results)\n        loss_track = self.track_head.loss(*match_feats, *asso_targets)\n\n        losses.update(loss_track)\n        losses.update(rpn_losses)\n\n        return losses\n\n    def simple_test(self, img, img_metas, rescale=False, ref_img=None, **kwargs):\n        \"\"\"Test function without test time augmentation.\n\n        Args:\n            imgs (list[torch.Tensor]): List of multiple images\n            img_metas (list[dict]): List of image information.\n            rescale (bool): Whether to rescale the results.\n                Defaults to False.\n\n        Returns:\n            list[list[np.ndarray]]: BBox results of each image and classes.\n                The outer list corresponds to each image. 
The inner list\n                corresponds to each class.\n        \"\"\"\n        # set the dataset type\n        # whether is the first frame for such clips\n        if self.cityscapes and not self.kitti_step:\n            iid = img_metas[0]['iid']\n            fid = iid % 10000\n            is_first = (fid == 1)\n        else:\n            iid = kwargs['img_id'][0].item()\n            fid = iid % 10000\n            is_first = (fid == 0)\n\n        if is_first:\n            self.init_tracker()\n            self.obj_feats_memory = None\n            self.x_feats_memory = None\n            self.mask_preds_memory = None\n\n        # for current frame\n        x = self.extract_feat(img)\n        # current frame inference\n        rpn_results = self.rpn_head.simple_test_rpn(x, img_metas)\n        (proposal_feats, x_feats, mask_preds, cls_scores,\n         seg_preds) = rpn_results\n\n        cur_segm_results, query_output, cls_scores, mask_preds, scaled_mask_preds = self.roi_head.simple_test(\n                x_feats,\n                proposal_feats,\n                mask_preds,\n                cls_scores,\n                img_metas)\n\n        # for tracking part\n        _, segm_result, mask_preds, panoptic_result, query_output = cur_segm_results[0]\n        panoptic_seg, segments_info = panoptic_result\n\n\n        # get sorted tracking thing ids, labels, masks, score for tracking\n        things_index_for_tracking, things_labels_for_tracking, thing_masks_for_tracking, things_score_for_tracking = \\\n            self.get_things_id_for_tracking(panoptic_seg, segments_info)\n        things_labels_for_tracking = torch.Tensor(things_labels_for_tracking).to(cls_scores.device).long()\n\n        # get the semantic filter\n        if self.semantic_filter:\n            seg_preds = torch.nn.functional.interpolate(seg_preds, panoptic_seg.shape, mode='bilinear',\n                                                        align_corners=False)\n            seg_preds = 
seg_preds.sigmoid()\n            seg_out = seg_preds.argmax(1)\n            semantic_thing = (seg_out < self.num_thing_classes).to(dtype=torch.float32)\n        else:\n            semantic_thing = 1.\n\n        if len(things_labels_for_tracking) > 0:\n            things_bbox_for_tracking = torch.zeros((len(things_score_for_tracking), 5),\n                                                   dtype=torch.float, device=x_feats.device)\n            things_bbox_for_tracking[:, 4] = torch.tensor(things_score_for_tracking,\n                                                          device=things_bbox_for_tracking.device)\n\n            thing_masks_for_tracking_final = []\n            for mask in thing_masks_for_tracking:\n                thing_masks_for_tracking_final.append(torch.Tensor(mask).unsqueeze(0).to(\n                    x_feats.device).float())\n            thing_masks_for_tracking_final = torch.cat(thing_masks_for_tracking_final, 0)\n            thing_masks_for_tracking = thing_masks_for_tracking_final\n            thing_masks_for_tracking_with_semantic_filter = thing_masks_for_tracking_final * semantic_thing\n\n        if len(things_labels_for_tracking) == 0:\n            track_feats = None\n        else:\n            # tracking embeddings\n            N, _, _, _ = query_output.shape\n            emb_feat = query_output.squeeze(-2).squeeze(-1).unsqueeze(0)  # (n,d,1,1) -> (1,n,d)\n\n            # for emb_layer in self.embed_fcs:\n            #     emb_feat = emb_layer(emb_feat)\n            # object_feats_embed = self.fc_embed(emb_feat).view(1, N, -1)\n\n            track_feats = emb_feat.squeeze(0)\n            # tracking embedding features\n            # track_feats = self._track_forward([object_feats_embed_for_tracking])\n\n        if track_feats is not None:\n            things_bbox_for_tracking[:, :4] = torch.tensor(tensor_mask2box(thing_masks_for_tracking_with_semantic_filter),\n                                                           
device=things_bbox_for_tracking.device)\n            bboxes, labels, ids = self.tracker.match(\n                bboxes=things_bbox_for_tracking,\n                labels=things_labels_for_tracking,\n                track_feats=track_feats,\n                frame_id=fid)\n            ids = ids + 1\n            ids[ids == -1] = 0\n        else:\n            ids = []\n\n        track_maps = self.generate_track_id_maps(ids, thing_masks_for_tracking, panoptic_seg)\n\n        semantic_map = self.get_semantic_seg(panoptic_seg, segments_info)\n\n        from scripts.visualizer import trackmap2rgb, cityscapes_cat2rgb, draw_bbox_on_img\n        vis_tracker = trackmap2rgb(track_maps)\n        vis_sem = cityscapes_cat2rgb(semantic_map)\n        if len(things_labels_for_tracking):\n            vis_tracker = draw_bbox_on_img(vis_tracker, things_bbox_for_tracking.cpu().numpy())\n\n        # Visualization usage\n        return semantic_map, track_maps, None, vis_sem, vis_tracker\n\n    def _track_forward(self, track_feats, x=None, mask_pred=None):\n        \"\"\"Track head forward function used in both training and testing.\n        We use mask pooling to get the fine grain features\"\"\"\n        # if not self.training:\n        #     mask_pred = [mask_pred]\n        # bbox_list = batch_mask2boxlist(mask_pred)\n        # track_rois = bboxlist2roi(bbox_list)\n        # track_rois = track_rois.clamp(min=0.0)\n        # track_feats = self.track_roi_extractor(x[:self.track_roi_extractor.num_inputs], track_rois)\n        track_feats = torch.cat(track_feats, 0)\n        # print(track_feats.shape)\n        # print(track_feats.shape)\n        # track_feats = track_feats\n\n        track_feats = self.track_head(track_feats)\n\n        return track_feats\n\n    def forward_dummy(self, img):\n        \"\"\"Used for computing network flops.\n\n        See `mmdetection/tools/get_flops.py`\n        \"\"\"\n        # backbone\n        x = self.extract_feat(img)\n        # rpn\n        num_imgs = 
len(img)\n        dummy_img_metas = [\n            dict(img_shape=(800, 1333, 3)) for _ in range(num_imgs)\n        ]\n        rpn_results = self.rpn_head.simple_test_rpn(x, dummy_img_metas)\n        (proposal_feats, x_feats, mask_preds, cls_scores,\n         seg_preds) = rpn_results\n        roi_outs = self.roi_head.forward_dummy(x_feats, proposal_feats,\n                                               dummy_img_metas)\n        return roi_outs\n\n    def extract_feat(self, img):\n        \"\"\"Directly extract features from the backbone+neck.\"\"\"\n        x = self.backbone(img)\n        if self.with_neck:\n            x = self.neck(x)\n        return x\n\n    @property\n    def with_rpn(self):\n        \"\"\"bool: whether the detector has RPN\"\"\"\n        return hasattr(self, 'rpn_head') and self.rpn_head is not None\n\n    @property\n    def with_roi_head(self):\n        \"\"\"bool: whether the detector has a RoI head\"\"\"\n        return hasattr(self, 'roi_head') and self.roi_head is not None\n\n    def aug_test(self, x, proposal_list, img_metas, rescale=False, **kwargs):\n        \"\"\"Test with augmentations.\n\n        If rescale is False, then returned bboxes and masks will fit the scale\n        of imgs[0].\n        \"\"\"\n        pass\n\n    def get_things_id_for_tracking(self, panoptic_seg, seg_infos):\n        idxs = []\n        labels = []\n        masks = []\n        score = []\n        for segment in seg_infos:\n            if segment['isthing'] == True:\n                thing_mask = panoptic_seg == segment[\"id\"]\n                masks.append(thing_mask)\n                idxs.append(segment[\"instance_id\"])\n                labels.append(segment['category_id'])\n                score.append(segment['score'])\n        return idxs, labels, masks, score\n\n    def pack_things_object(self, object_feats, ref_object_feats):\n        object_feats, ref_object_feats = object_feats.squeeze(-1).squeeze(-1), ref_object_feats.squeeze(-1).squeeze(-1)\n      
  thing_object_feats = torch.split(object_feats, [self.roi_head.num_proposals, self.num_stuff_classes], dim=1)[0]\n        ref_thing_object_feats = torch.split(ref_object_feats, [self.roi_head.num_proposals, self.num_stuff_classes], dim=1)[0]\n        return thing_object_feats, ref_thing_object_feats\n\n    def pack_things_masks(self, mask_pred, ref_mask_pred):\n        thing_mask_pred = torch.split(mask_pred, [self.roi_head.num_proposals, self.num_stuff_classes], dim=1)[0]\n        ref_thing_thing_mask_pred = torch.split(ref_mask_pred, [self.roi_head.num_proposals, self.num_stuff_classes], dim=1)[0]\n        return thing_mask_pred, ref_thing_thing_mask_pred\n\n    def get_semantic_seg(self, panoptic_seg, segments_info):\n        results = {}\n        masks = []\n        scores = []\n        kitti_step2cityscpaes = [11, 13]\n        semantic_seg = np.zeros(panoptic_seg.shape)\n        for segment in segments_info:\n            if segment['isthing'] == True:\n                if self.kitti_step:\n                    cat_cur = kitti_step2cityscpaes[segment[\"category_id\"]]\n                    semantic_seg[panoptic_seg == segment[\"id\"]] = cat_cur\n                else:\n                    semantic_seg[panoptic_seg == segment[\"id\"]] = segment[\"category_id\"] + self.num_stuff_classes\n            else:\n                # for stuff (0- n-1)\n                if self.kitti_step:\n                    cat_cur = segment[\"category_id\"]\n                    cat_cur -= 1\n                    offset = 0\n                    for thing_id in kitti_step2cityscpaes:\n                        if cat_cur + offset >= thing_id:\n                            offset += 1\n                    cat_cur += offset\n                    semantic_seg[panoptic_seg == segment[\"id\"]] = cat_cur\n                else:\n                    semantic_seg[panoptic_seg == segment[\"id\"]] = segment[\"category_id\"] - 1\n        return semantic_seg\n\n    def generate_track_id_maps(self, ids, masks, 
panoptic_seg_maps):\n        \"\"\"Paint each tracked thing mask with its track id (0 = background or untracked).\"\"\"\n        final_id_maps = np.zeros(panoptic_seg_maps.shape)\n        if len(ids) == 0:\n            return final_id_maps\n        masks = masks.bool()\n        for i, track_id in enumerate(ids):\n            mask = masks[i].cpu().numpy()\n            final_id_maps[mask] = int(track_id)\n        return final_id_maps"
  },
  {
    "path": "knet/video/knet_quansi_dense_roi_gt_box.py",
    "content": "import warnings\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom mmcv.cnn import ConvModule\nfrom mmdet.models.builder import DETECTORS\nfrom mmdet.models.detectors import BaseDetector\nfrom mmdet.models.builder import build_head, build_neck, build_backbone, build_roi_extractor\nfrom mmdet.core import build_assigner, build_sampler\nfrom knet.video.qdtrack.builder import build_tracker\nfrom knet.det.utils import sem2ins_masks, sem2ins_masks_cityscapes, sem2ins_masks_kitti_step\nfrom unitrack.mask import tensor_mask2box\nfrom unitrack.utils.mask import mask2box, batch_mask2boxlist, bboxlist2roi\n\n@DETECTORS.register_module()\nclass VideoKNetQuansiTrackROIGTBox(BaseDetector):\n    \"\"\"\n        Simple Extension of KNet to Video KNet by the implementation of VPSFuse Net.\n    \"\"\"\n    def __init__(self,\n                 backbone,\n                 neck=None,\n                 rpn_head=None,\n                 roi_head=None,\n                 track_head=None,\n                 extra_neck=None,\n                 track_localization_fpn=None,\n                 tracker=None,\n                 train_cfg=None,\n                 test_cfg=None,\n                 track_train_cfg=None,\n                 pretrained=None,\n                 init_cfg=None,\n                 num_thing_classes=80,\n                 num_stuff_classes=53,\n                 mask_assign_stride=4,\n                 ignore_label=255,\n                 thing_label_in_seg=0,\n                 cityscapes=False,\n                 kitti_step=False,\n                 freeze_detector=False,\n                 semantic_filter=False,\n                 # linking parameters\n                 link_previous=False,\n                 bbox_roi_extractor=dict(\n                     type='SingleRoIExtractor',\n                     roi_layer=dict(\n                         type='RoIAlign', output_size=7, sampling_ratio=2),\n                     out_channels=256,\n                    
 featmap_strides=[4, 8, 16, 32]),\n                 **kwargs):\n        super(VideoKNetQuansiTrackROIGTBox, self).__init__(init_cfg)\n\n        if pretrained:\n            warnings.warn('DeprecationWarning: pretrained is deprecated, '\n                          'please use \"init_cfg\" instead')\n            backbone.pretrained = pretrained\n        self.backbone = build_backbone(backbone)\n\n        if neck is not None:\n            self.neck = build_neck(neck)\n\n        if extra_neck is not None:\n            self.extra_neck = build_neck(extra_neck)\n\n        if rpn_head is not None:\n            rpn_train_cfg = train_cfg.rpn if train_cfg is not None else None\n            rpn_head_ = rpn_head.copy()\n            rpn_head_.update(train_cfg=rpn_train_cfg, test_cfg=test_cfg.rpn)\n            self.rpn_head = build_head(rpn_head_)\n\n        if roi_head is not None:\n            # update train and test cfg here for now\n            # TODO: refactor assigner & sampler\n            rcnn_train_cfg = train_cfg.rcnn if train_cfg is not None else None\n            roi_head.update(train_cfg=rcnn_train_cfg)\n            roi_head.update(test_cfg=test_cfg.rcnn)\n            roi_head.pretrained = pretrained\n            self.roi_head = build_head(roi_head)\n\n        if track_head is not None:\n            self.track_train_cfg = track_train_cfg\n            self.track_head = build_head(track_head)\n            self.init_track_assigner_sampler()\n            if track_localization_fpn is not None:\n                self.track_localization_fpn = build_neck(track_localization_fpn)\n\n            self.track_roi_extractor = build_roi_extractor(\n                bbox_roi_extractor)\n\n        if tracker is not None:\n            self.tracker_cfg = tracker\n\n        if freeze_detector:\n           self._freeze_detector()\n\n        self.train_cfg = train_cfg\n        self.test_cfg = test_cfg\n        self.num_proposals = self.rpn_head.num_proposals\n        self.num_thing_classes = 
num_thing_classes\n        self.num_stuff_classes = num_stuff_classes\n        self.mask_assign_stride = mask_assign_stride\n        self.thing_label_in_seg = thing_label_in_seg\n        self.ignore_label = ignore_label\n        self.cityscapes = cityscapes  # whether to train the cityscape panoptic segmentation\n        self.kitti_step = kitti_step  # whether to train the kitti step panoptic segmentation\n\n        self.semantic_filter = semantic_filter\n        self.link_previous = link_previous\n\n    def init_tracker(self):\n        self.tracker = build_tracker(self.tracker_cfg)\n\n    def _freeze_detector(self):\n\n        self.detector = [\n            self.rpn_head, self.roi_head\n        ]\n        for model in self.detector:\n            model.eval()\n            for param in model.parameters():\n                param.requires_grad = False\n\n    def init_track_assigner_sampler(self):\n        \"\"\"Initialize assigner and sampler.\"\"\"\n\n        self.track_roi_assigner = build_assigner(\n            self.track_train_cfg.assigner)\n        self.track_share_assigner = False\n\n        self.track_roi_sampler = build_sampler(\n            self.track_train_cfg.sampler, context=self)\n        self.track_share_sampler = False\n\n    def preprocess_gt_masks(self, img_metas, gt_masks, gt_labels, gt_semantic_seg):\n        # gt_masks and gt_semantic_seg are not padded when forming batch\n        gt_masks_tensor = []\n        gt_sem_seg = []\n        gt_sem_cls = []\n        # batch_input_shape shoud be the same across images\n        pad_H, pad_W = img_metas[0]['batch_input_shape']\n        assign_H = pad_H // self.mask_assign_stride\n        assign_W = pad_W // self.mask_assign_stride\n\n        for i, gt_mask in enumerate(gt_masks):\n            mask_tensor = gt_mask.to_tensor(torch.float, gt_labels[0].device)\n            if gt_mask.width != pad_W or gt_mask.height != pad_H:\n                pad_wh = (0, pad_W - gt_mask.width, 0, pad_H - gt_mask.height)\n      
          mask_tensor = F.pad(mask_tensor, pad_wh, value=0)\n\n            if gt_semantic_seg is not None:\n                # gt_semantic seg is padded by zero when forming a batch\n                # need to convert them from 0 to ignore\n                gt_semantic_seg[\n                i, :, img_metas[i]['img_shape'][0]:, :] = self.ignore_label\n                gt_semantic_seg[\n                i, :, :, img_metas[i]['img_shape'][1]:] = self.ignore_label\n                if self.cityscapes:\n                    sem_labels, sem_seg = sem2ins_masks_cityscapes(\n                        gt_semantic_seg[i],\n                        ignore_label=self.ignore_label,\n                        label_shift=self.num_thing_classes,\n                        thing_label_in_seg=list(range(self.num_stuff_classes,\n                                                      self.num_thing_classes + self.num_stuff_classes))\n                    )\n                elif self.kitti_step:\n                    sem_labels, sem_seg = sem2ins_masks_kitti_step(\n                        gt_semantic_seg[i],\n                        ignore_label=self.ignore_label,\n                        label_shift=2,\n                        thing_label_in_seg=(11, 13))\n                else:\n                    sem_labels, sem_seg = sem2ins_masks(\n                        gt_semantic_seg[i],\n                        ignore_label=self.ignore_label,\n                        label_shift=self.num_thing_classes,\n                        thing_label_in_seg=self.thing_label_in_seg)\n\n                if sem_seg.shape[0] == 0:\n                    gt_sem_seg.append(\n                        mask_tensor.new_zeros(\n                            (mask_tensor.size(0), assign_H, assign_W)))\n                else:\n                    gt_sem_seg.append(\n                        F.interpolate(\n                            sem_seg[None], (assign_H, assign_W),\n                            mode='bilinear',\n                         
   align_corners=False)[0])\n                gt_sem_cls.append(sem_labels)\n            else:\n                gt_sem_seg = None\n                gt_sem_cls = None\n\n            if mask_tensor.shape[0] == 0:\n                gt_masks_tensor.append(\n                    mask_tensor.new_zeros(\n                        (mask_tensor.size(0), assign_H, assign_W)))\n            else:\n                gt_masks_tensor.append(\n                    F.interpolate(\n                        mask_tensor[None], (assign_H, assign_W),  # downsample to 1/4 resolution\n                        mode='bilinear',\n                        align_corners=False)[0])\n\n        return gt_masks_tensor, gt_sem_cls, gt_sem_seg\n\n    def forward_train(self,\n                      img,\n                      img_metas,\n                      gt_bboxes=None,\n                      gt_labels=None,\n                      gt_bboxes_ignore=None,\n                      gt_masks=None,\n                      gt_semantic_seg=None,\n                      gt_instance_ids=None,\n                      ref_img=None,\n                      ref_img_metas=None,\n                      ref_gt_bboxes_ignore=None,\n                      ref_gt_labels=None,\n                      ref_gt_bboxes=None,\n                      ref_gt_masks=None,\n                      ref_gt_semantic_seg=None,\n                      ref_gt_instance_ids=None,\n                      proposals=None,\n                      **kwargs):\n        \"\"\"Forward function of SparseR-CNN-like network in train stage.\n\n        Args:\n            img (Tensor): of shape (N, C, H, W) encoding input images.\n                Typically these should be mean centered and std scaled.\n            img_metas (list[dict]): list of image info dict where each dict\n                has: 'img_shape', 'scale_factor', 'flip', and may also contain\n                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.\n                For details on the values of 
these keys see\n                :class:`mmdet.datasets.pipelines.Collect`.\n            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with\n                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.\n            gt_labels (list[Tensor]): class indices corresponding to each box\n            gt_bboxes_ignore (None | list[Tensor]): specify which bounding\n                boxes can be ignored when computing the loss.\n            gt_masks (list[BitmapMasks]): Ground truth instance masks for\n                each image. Required by this architecture.\n            proposals (List[Tensor], optional): external proposals. Not\n                supported by KNet; must be None.\n\n            # This is for video only:\n            ref_img (Tensor): of shape (N, 1, C, H, W) encoding the reference\n                images. Typically these should be mean centered and std scaled.\n                The second dimension holds one reference image per key image.\n\n            ref_img_metas (list[list[dict]]): The first list only has one\n                element. The second list contains reference image information\n                dict where each dict has: 'img_shape', 'scale_factor', 'flip',\n                and may also contain 'filename', 'ori_shape', 'pad_shape', and\n                'img_norm_cfg'. For details on the values of these keys see\n                `mmtrack/datasets/pipelines/formatting.py:VideoCollect`.\n\n            ref_gt_bboxes (list[Tensor]): The list only has one Tensor. The\n                Tensor contains ground truth bboxes for each reference image\n                with shape (num_all_ref_gts, 5) in\n                [ref_img_id, tl_x, tl_y, br_x, br_y] format. The ref_img_id\n                starts from 0 and denotes the index of the reference image for\n                each key image.\n\n            ref_gt_labels (list[Tensor]): The list only has one Tensor. 
The\n                Tensor contains class indices corresponding to each reference\n                box with shape (num_all_ref_gts, 2) in\n                [ref_img_id, class_indice].\n\n        Returns:\n            dict[str, Tensor]: a dictionary of loss components\n        \"\"\"\n        batch_input_shape = tuple(img[0].size()[-2:])\n        for img_meta in img_metas:\n            img_meta['batch_input_shape'] = batch_input_shape\n\n        assert proposals is None, 'KNet does not support' \\\n                                  ' external proposals'\n        assert gt_masks is not None\n        assert gt_instance_ids is not None\n\n        # preprocess the reference images\n        ref_img = ref_img.squeeze(1)  # (b,3,h,w)\n        img_h, img_w = batch_input_shape\n        ref_masks_gt = []\n        for ref_gt_mask in ref_gt_masks:\n            ref_masks_gt.append(ref_gt_mask[0])\n\n        ref_labels_gt = []\n        for ref_gt_label in ref_gt_labels:\n            ref_labels_gt.append(ref_gt_label[:, 1].long())\n        ref_gt_labels = ref_labels_gt\n\n        ref_semantic_seg_gt = ref_gt_semantic_seg.squeeze(1)\n\n        ref_gt_instance_id_list = []\n        for ref_gt_instance_id in ref_gt_instance_ids:\n            ref_gt_instance_id_list.append(ref_gt_instance_id[:,1].long())\n\n        ref_img_metas_new = []\n        for ref_img_meta in ref_img_metas:\n            ref_img_meta[0]['batch_input_shape'] = batch_input_shape\n            ref_img_metas_new.append(ref_img_meta[0])\n\n        # prepare the gt_match_indices\n        gt_pids_list = []\n        for i in range(len(ref_gt_instance_id_list)):\n            ref_ids = ref_gt_instance_id_list[i].cpu().data.numpy().tolist()\n            gt_ids = gt_instance_ids[i].cpu().data.numpy().tolist()\n            gt_pids = [ref_ids.index(i) if i in ref_ids else -1 for i in gt_ids]\n            gt_pids_list.append(torch.LongTensor([gt_pids]).to(img.device)[0])\n\n        gt_match_indices = gt_pids_list\n\n        # 
gt_masks and gt_semantic_seg are not padded when forming batch\n        gt_masks, gt_sem_cls, gt_sem_seg = self.preprocess_gt_masks(img_metas, gt_masks, gt_labels, gt_semantic_seg)\n\n        ref_gt_masks, ref_gt_sem_cls, ref_gt_sem_seg = self.preprocess_gt_masks(ref_img_metas_new,\n                                                                    ref_masks_gt, ref_gt_labels, ref_semantic_seg_gt)\n\n        x = self.extract_feat(img)\n        x_ref = self.extract_feat(ref_img)\n\n        rpn_results = self.rpn_head.forward_train(x, img_metas, gt_masks,\n                                                  gt_labels, gt_sem_seg,\n                                                  gt_sem_cls)\n\n        # simple forward to get the reference results\n        self.rpn_head.eval()\n        ref_rpn_results = self.rpn_head.simple_test_rpn(x_ref, ref_img_metas_new)\n        self.rpn_head.train()\n\n        (rpn_losses, proposal_feats, x_feats, mask_preds,\n         cls_scores) = rpn_results\n\n        (ref_proposal_feats, ref_x_feats, ref_mask_preds,\n         ref_cls_scores, ref_seg_preds) = ref_rpn_results\n\n        ref_obj_feats,  ref_cls_scores, ref_mask_preds, ref_scaled_mask_preds = self.roi_head.simple_test_mask_preds(\n            ref_x_feats,\n            ref_proposal_feats,\n            ref_mask_preds,\n            ref_cls_scores,\n            ref_img_metas_new,\n           )\n\n        if self.link_previous:\n            losses, object_feats, cls_scores, mask_preds, scaled_mask_preds = self.roi_head.forward_train_with_previous(\n                x_feats,\n                proposal_feats,\n                mask_preds,\n                cls_scores,\n                img_metas,\n                gt_masks,\n                gt_labels,\n                gt_bboxes_ignore=gt_bboxes_ignore,\n                gt_bboxes=gt_bboxes,\n                gt_sem_seg=gt_sem_seg,\n                gt_sem_cls=gt_sem_cls,\n                imgs_whwh=None,\n                
previous_obj_feats=ref_obj_feats,\n                previous_mask_preds=ref_scaled_mask_preds,\n                previous_x_feats=ref_x_feats,\n            )\n        else:\n            # forward to get the current results\n            losses, object_feats, cls_scores, mask_preds, scaled_mask_preds = self.roi_head.forward_train(\n                x_feats,\n                proposal_feats,\n                mask_preds,\n                cls_scores,\n                img_metas,\n                gt_masks,\n                gt_labels,\n                gt_bboxes_ignore=gt_bboxes_ignore,\n                gt_bboxes=gt_bboxes,\n                gt_sem_seg=gt_sem_seg,\n                gt_sem_cls=gt_sem_cls,\n                imgs_whwh=None)\n\n        # ===== Tracking Part -==== #\n        # assign both key frame and reference frame tracking targets\n        key_sampling_results, ref_sampling_results = [], []\n        num_imgs = len(img_metas)\n\n        for i in range(num_imgs):\n            assign_result = self.track_roi_assigner.assign(\n                scaled_mask_preds[i][:self.num_proposals].detach(), cls_scores[i][:self.num_proposals, :self.num_thing_classes].detach(),\n                gt_masks[i], gt_labels[i], img_meta=img_metas[i])\n            sampling_result = self.track_roi_sampler.sample(\n                assign_result,\n                mask_preds[i][:self.num_proposals].detach(),\n                gt_masks[i])\n            key_sampling_results.append(sampling_result)\n\n            ref_assign_result = self.track_roi_assigner.assign(\n                ref_scaled_mask_preds[i][:self.num_proposals].detach(), ref_cls_scores[i][:self.num_proposals, :self.num_thing_classes].detach(),\n                ref_gt_masks[i], ref_gt_labels[i], img_meta=ref_img_metas_new[i])\n            ref_sampling_result = self.track_roi_sampler.sample(\n                ref_assign_result,\n                ref_mask_preds[i][:self.num_proposals].detach(),\n                ref_gt_masks[i])\n            
ref_sampling_results.append(ref_sampling_result)\n\n        # roi feature embeddings\n        key_masks = [res.pos_gt_masks for res in key_sampling_results]\n        for i in range(len(key_masks)):\n            key_masks[i] = F.interpolate(key_masks[i].unsqueeze(0),\n                                        size=(img_h, img_w), mode=\"bilinear\", align_corners=False).squeeze(0)\n            key_masks[i] = (key_masks[i].sigmoid() > 0.5).float()\n\n        key_feats = self._track_forward(x, key_masks)\n\n        # roi feature embeddings\n        ref_masks = [res.pos_gt_masks for res in ref_sampling_results]\n        for i in range(len(ref_masks)):\n            ref_masks[i] = F.interpolate(ref_masks[i].unsqueeze(0),\n                                        size=(img_h, img_w), mode=\"bilinear\", align_corners=False).squeeze(0)\n            ref_masks[i] = (ref_masks[i].sigmoid() > 0.5).float()\n\n        ref_feats = self._track_forward(x_ref, ref_masks)\n\n        match_feats = self.track_head.match(key_feats, ref_feats,\n                                            key_sampling_results,\n                                            ref_sampling_results)\n\n        asso_targets = self.track_head.get_track_targets(\n            gt_match_indices, key_sampling_results, ref_sampling_results)\n        loss_track = self.track_head.loss(*match_feats, *asso_targets)\n\n        losses.update(loss_track)\n        losses.update(rpn_losses)\n\n        return losses\n\n    def simple_test(self, img, img_metas, rescale=False, ref_img=None, **kwargs):\n        \"\"\"Test function without test time augmentation.\n\n        Args:\n            imgs (list[torch.Tensor]): List of multiple images\n            img_metas (list[dict]): List of image information.\n            rescale (bool): Whether to rescale the results.\n                Defaults to False.\n\n        Returns:\n            list[list[np.ndarray]]: BBox results of each image and classes.\n                The outer list corresponds 
to each image. The inner list\n                corresponds to each class.\n        \"\"\"\n\n        # whether is the first frame for such clips\n        # assert 'city' in img_metas[0]['filename'] and 'iid' in img_metas[0]\n        if \"city\" in img_metas[0]['filename']:\n            iid = img_metas[0]['iid']\n            fid = iid % 10000\n            is_first = (fid == 1)\n        elif \"motchallenge\" in img_metas[0]['filename']:\n            iid = kwargs['img_id'][0].item()\n            fid = iid % 10000\n            is_first = (fid == 1)\n            if is_first:\n                print(\"First detected on {}\".format(fid))\n        else:\n            iid = kwargs['img_id'][0].item()\n            fid = iid % 10000\n            is_first = (fid == 0)\n\n        if is_first:\n            self.init_tracker()\n            self.obj_feats_memory = None\n            self.x_feats_memory = None\n            self.mask_preds_memory = None\n\n        # for current frame\n        x = self.extract_feat(img)\n        # current frame inference\n        rpn_results = self.rpn_head.simple_test_rpn(x, img_metas)\n        (proposal_feats, x_feats, mask_preds, cls_scores,\n         seg_preds) = rpn_results\n\n        if self.link_previous:\n            cur_segm_results, query_output, cls_scores, mask_preds, scaled_mask_preds = self.roi_head.simple_test_with_previous(\n                x_feats,\n                proposal_feats,\n                mask_preds,\n                cls_scores,\n                img_metas,\n                previous_obj_feats=self.obj_feats_memory,\n                previous_mask_preds=self.mask_preds_memory,\n                previous_x_feats=self.x_feats_memory,\n            )\n            self.obj_feats_memory = query_output\n            self.x_feats_memory = x_feats\n            self.mask_preds_memory = scaled_mask_preds\n        else:\n            cur_segm_results, query_output, cls_scores, mask_preds, scaled_mask_preds = self.roi_head.simple_test(\n          
      x_feats,\n                proposal_feats,\n                mask_preds,\n                cls_scores,\n                img_metas)\n\n        # for tracking part\n        _, segm_result, mask_preds, panoptic_result, _ = cur_segm_results[0]\n        panoptic_seg, segments_info = panoptic_result\n\n        if self.semantic_filter:\n            seg_preds = torch.nn.functional.interpolate(seg_preds, panoptic_seg.shape, mode='bilinear', align_corners=False)\n            seg_preds = seg_preds.sigmoid()\n            seg_out = seg_preds.argmax(1)\n            semantic_thing = (seg_out < self.num_thing_classes).to(dtype=torch.float32)\n        else:\n            semantic_thing = 1.\n\n        # get sorted tracking thing ids, labels, masks, score for tracking\n        things_index_for_tracking, things_labels_for_tracking, thing_masks_for_tracking, things_score_for_tracking = \\\n            self.get_things_id_for_tracking(panoptic_seg, segments_info)\n        things_labels_for_tracking = torch.Tensor(things_labels_for_tracking).to(cls_scores.device).long()\n        if len(things_labels_for_tracking) > 0:\n            things_bbox_for_tracking = torch.zeros((len(things_score_for_tracking), 5),\n                                                   dtype=torch.float, device=x_feats.device)\n            things_bbox_for_tracking[:, 4] = torch.tensor(things_score_for_tracking,\n                                                          device=things_bbox_for_tracking.device)\n\n            thing_masks_for_tracking_final = []\n            for mask in thing_masks_for_tracking:\n                thing_masks_for_tracking_final.append(torch.Tensor(mask).unsqueeze(0).to(\n                    x_feats.device).float())\n            thing_masks_for_tracking_final = torch.cat(thing_masks_for_tracking_final, 0)\n            thing_masks_for_tracking = thing_masks_for_tracking_final\n            thing_masks_for_tracking_with_semantic_filter = thing_masks_for_tracking_final * semantic_thing\n\n    
    if len(things_labels_for_tracking) == 0:\n            track_feats = None\n        else:\n            # tracking embedding features\n            track_feats = self._track_forward(x, thing_masks_for_tracking_with_semantic_filter)\n\n        if track_feats is not None:\n            # assert len(things_id_for_tracking) == len(things_labels_for_tracking)\n            things_bbox_for_tracking[:, :4] = torch.tensor(tensor_mask2box(thing_masks_for_tracking_with_semantic_filter),\n                                                           device=things_bbox_for_tracking.device)\n            bboxes, labels, ids = self.tracker.match(\n                bboxes=things_bbox_for_tracking,\n                labels=things_labels_for_tracking,\n                track_feats=track_feats,\n                frame_id=fid)\n            ids = ids + 1\n            ids[ids == -1] = 0\n        else:\n            ids = []\n\n        track_maps = self.generate_track_id_maps(ids, thing_masks_for_tracking, panoptic_seg)\n\n        semantic_map = self.get_semantic_seg(panoptic_seg, segments_info)\n\n        from scripts.visualizer import trackmap2rgb, cityscapes_cat2rgb, draw_bbox_on_img\n        vis_tracker = trackmap2rgb(track_maps)\n        vis_sem = cityscapes_cat2rgb(semantic_map)\n        if len(things_labels_for_tracking):\n            vis_tracker = draw_bbox_on_img(vis_tracker, things_bbox_for_tracking.cpu().numpy())\n\n        # Visualization end\n        return semantic_map, track_maps, None, vis_sem, vis_tracker\n\n\n    def _track_forward(self, x, mask_pred):\n        \"\"\"Track head forward function used in both training and testing.\n        We use mask pooling to get the fine grain features\"\"\"\n        if not self.training:\n            mask_pred = [mask_pred]\n        bbox_list = batch_mask2boxlist(mask_pred)\n        track_rois = bboxlist2roi(bbox_list)\n        track_rois = track_rois.clamp(min=0.0)\n        track_feats = 
self.track_roi_extractor(x[:self.track_roi_extractor.num_inputs], track_rois)\n        track_feats = self.track_head(track_feats)\n\n        return track_feats\n\n    def forward_dummy(self, img):\n        \"\"\"Used for computing network flops.\n\n        See `mmdetection/tools/get_flops.py`\n        \"\"\"\n        # backbone\n        x = self.extract_feat(img)\n        # rpn\n        num_imgs = len(img)\n        dummy_img_metas = [\n            dict(img_shape=(800, 1333, 3)) for _ in range(num_imgs)\n        ]\n        rpn_results = self.rpn_head.simple_test_rpn(x, dummy_img_metas)\n        (proposal_feats, x_feats, mask_preds, cls_scores,\n         seg_preds) = rpn_results\n        roi_outs = self.roi_head.forward_dummy(x_feats, proposal_feats,\n                                               dummy_img_metas)\n        return roi_outs\n\n    def extract_feat(self, img):\n        \"\"\"Directly extract features from the backbone+neck.\"\"\"\n        x = self.backbone(img)\n        if self.with_neck:\n            x = self.neck(x)\n        return x\n\n    @property\n    def with_rpn(self):\n        \"\"\"bool: whether the detector has RPN\"\"\"\n        return hasattr(self, 'rpn_head') and self.rpn_head is not None\n\n    @property\n    def with_roi_head(self):\n        \"\"\"bool: whether the detector has a RoI head\"\"\"\n        return hasattr(self, 'roi_head') and self.roi_head is not None\n\n    def aug_test(self, x, proposal_list, img_metas, rescale=False, **kwargs):\n        \"\"\"Test with augmentations.\n\n        If rescale is False, then returned bboxes and masks will fit the scale\n        of imgs[0].\n        \"\"\"\n        pass\n\n    def get_things_id_for_tracking(self, panoptic_seg, seg_infos):\n        idxs = []\n        labels = []\n        masks = []\n        score = []\n        for segment in seg_infos:\n            if segment['isthing'] == True:\n                thing_mask = panoptic_seg == segment[\"id\"]\n                
masks.append(thing_mask)\n                idxs.append(segment[\"instance_id\"])\n                labels.append(segment['category_id'])\n                score.append(segment['score'])\n        return idxs, labels, masks, score\n\n\n    def pack_things_object(self, object_feats, ref_object_feats):\n        object_feats, ref_object_feats = object_feats.squeeze(-1).squeeze(-1), ref_object_feats.squeeze(-1).squeeze(-1)\n        thing_object_feats = torch.split(object_feats, [self.roi_head.num_proposals, self.num_stuff_classes], dim=1)[0]\n        ref_thing_object_feats = torch.split(ref_object_feats, [self.roi_head.num_proposals, self.num_stuff_classes], dim=1)[0]\n        return thing_object_feats, ref_thing_object_feats\n\n    def pack_things_masks(self, mask_pred, ref_mask_pred):\n        thing_mask_pred = torch.split(mask_pred, [self.roi_head.num_proposals, self.num_stuff_classes], dim=1)[0]\n        ref_thing_mask_pred = torch.split(ref_mask_pred, [self.roi_head.num_proposals, self.num_stuff_classes], dim=1)[0]\n        return thing_mask_pred, ref_thing_mask_pred\n\n    def get_semantic_seg(self, panoptic_seg, segments_info):\n        \"\"\"Convert a panoptic result into a per-pixel semantic label map.\"\"\"\n        # KITTI-STEP thing classes (category ids 0, 1) map to these semantic label ids\n        kitti_step2cityscapes = [11, 13]\n        semantic_seg = np.zeros(panoptic_seg.shape)\n        for segment in segments_info:\n            if segment['isthing']:\n                if self.kitti_step:\n                    cat_cur = kitti_step2cityscapes[segment[\"category_id\"]]\n                    semantic_seg[panoptic_seg == segment[\"id\"]] = cat_cur\n                else:\n                    semantic_seg[panoptic_seg == segment[\"id\"]] = segment[\"category_id\"] + 11\n            else:\n                # stuff segments: shift 1-based category ids to 0 .. n-1\n                if self.kitti_step:\n                    cat_cur = segment[\"category_id\"]\n                    cat_cur -= 1\n                    # skip over the label ids reserved for thing classes\n                    offset = 0\n                    for thing_id in kitti_step2cityscapes:\n           
             if cat_cur + offset >= thing_id:\n                            offset += 1\n                    cat_cur += offset\n                    semantic_seg[panoptic_seg == segment[\"id\"]] = cat_cur\n                else:\n                    semantic_seg[panoptic_seg == segment[\"id\"]] = segment[\"category_id\"] - 1\n        return semantic_seg\n\n    def generate_track_id_maps(self, ids, masks, panoptic_seg_maps):\n        \"\"\"Paint each thing mask with its track id; pixels without a track id stay 0.\"\"\"\n        final_id_maps = np.zeros(panoptic_seg_maps.shape)\n        if len(ids) == 0:\n            return final_id_maps\n        # assert len(things_mask_results) == len(track_results)\n        masks = masks.bool()\n        for i, track_id in enumerate(ids):\n            mask = masks[i].cpu().numpy()\n            final_id_maps[mask] = track_id\n        return final_id_maps"
  },
  {
    "path": "knet/video/knet_quansi_dense_roi_gt_box_joint_train.py",
    "content": "import warnings\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom mmcv.cnn import ConvModule\nfrom mmdet.models.builder import DETECTORS\nfrom mmdet.models.detectors import BaseDetector\nfrom mmdet.models.builder import build_head, build_neck, build_backbone, build_roi_extractor\nfrom mmdet.core import build_assigner, build_sampler\nfrom knet.video.qdtrack.builder import build_tracker\nfrom knet.det.utils import sem2ins_masks, sem2ins_masks_cityscapes, sem2ins_masks_kitti_step\nfrom unitrack.mask import tensor_mask2box\nfrom unitrack.utils.mask import mask2box, batch_mask2boxlist, bboxlist2roi\n\n# RoI box based Video K-Net baseline.\n@DETECTORS.register_module()\nclass VideoKNetQuansiTrackROIGTBoxJointTrain(BaseDetector):\n    \"\"\"A simple extension of K-Net to Video K-Net, following the VPSFuse Net implementation.\n\n    Thing objects are associated across frames by a quasi-dense tracking\n    head that embeds RoI features pooled from boxes derived from masks.\n    \"\"\"\n    def __init__(self,\n                 backbone,\n                 neck=None,\n                 rpn_head=None,\n                 roi_head=None,\n                 track_head=None,\n                 extra_neck=None,\n                 track_localization_fpn=None,\n                 tracker=None,\n                 train_cfg=None,\n                 test_cfg=None,\n                 track_train_cfg=None,\n                 pretrained=None,\n                 init_cfg=None,\n                 num_thing_classes=80,\n                 num_stuff_classes=53,\n                 mask_assign_stride=4,\n                 ignore_label=255,\n                 thing_label_in_seg=0,\n                 cityscapes=False,\n                 kitti_step=False,\n                 freeze_detector=False,\n                 semantic_filter=False,\n                 # linking parameters\n                 link_previous=False,\n                 bbox_roi_extractor=dict(\n                     type='SingleRoIExtractor',\n                     roi_layer=dict(\n                         
type='RoIAlign', output_size=7, sampling_ratio=2),\n           
          out_channels=256,\n                     featmap_strides=[4, 8, 16, 32]),\n                 **kwargs):\n        super(VideoKNetQuansiTrackROIGTBoxJointTrain, self).__init__(init_cfg)\n\n        if pretrained:\n            warnings.warn('DeprecationWarning: pretrained is deprecated, '\n                          'please use \"init_cfg\" instead')\n            backbone.pretrained = pretrained\n        self.backbone = build_backbone(backbone)\n\n        if neck is not None:\n            self.neck = build_neck(neck)\n\n        if extra_neck is not None:\n            self.extra_neck = build_neck(extra_neck)\n\n        if rpn_head is not None:\n            rpn_train_cfg = train_cfg.rpn if train_cfg is not None else None\n            rpn_head_ = rpn_head.copy()\n            rpn_head_.update(train_cfg=rpn_train_cfg, test_cfg=test_cfg.rpn)\n            self.rpn_head = build_head(rpn_head_)\n\n        if roi_head is not None:\n            # update train and test cfg here for now\n            # TODO: refactor assigner & sampler\n            rcnn_train_cfg = train_cfg.rcnn if train_cfg is not None else None\n            roi_head.update(train_cfg=rcnn_train_cfg)\n            roi_head.update(test_cfg=test_cfg.rcnn)\n            roi_head.pretrained = pretrained\n            self.roi_head = build_head(roi_head)\n\n        if track_head is not None:\n            self.track_train_cfg = track_train_cfg\n            self.track_head = build_head(track_head)\n            self.init_track_assigner_sampler()\n            if track_localization_fpn is not None:\n                self.track_localization_fpn = build_neck(track_localization_fpn)\n\n            self.track_roi_extractor = build_roi_extractor(\n                bbox_roi_extractor)\n\n        if tracker is not None:\n            self.tracker_cfg = tracker\n\n        if freeze_detector:\n            self._freeze_detector()\n\n        self.train_cfg = train_cfg\n        self.test_cfg = test_cfg\n        self.num_proposals = 
self.rpn_head.num_proposals\n        self.num_thing_classes = num_thing_classes\n        self.num_stuff_classes = num_stuff_classes\n        self.mask_assign_stride = mask_assign_stride\n        self.thing_label_in_seg = thing_label_in_seg\n        self.ignore_label = ignore_label\n        self.cityscapes = cityscapes  # whether to train Cityscapes panoptic segmentation\n        self.kitti_step = kitti_step  # whether to train KITTI-STEP panoptic segmentation\n\n        self.semantic_filter = semantic_filter\n        self.link_previous = link_previous\n\n    def init_tracker(self):\n        \"\"\"Build a fresh tracker; called at the first frame of each clip.\"\"\"\n        self.tracker = build_tracker(self.tracker_cfg)\n\n    def _freeze_detector(self):\n        # put the RPN and RoI heads in eval mode and stop their gradients\n        self.detector = [\n            self.rpn_head, self.roi_head\n        ]\n        for model in self.detector:\n            model.eval()\n            for param in model.parameters():\n                param.requires_grad = False\n\n    def init_track_assigner_sampler(self):\n        \"\"\"Initialize assigner and sampler.\"\"\"\n\n        self.track_roi_assigner = build_assigner(\n            self.track_train_cfg.assigner)\n        self.track_share_assigner = False\n\n        self.track_roi_sampler = build_sampler(\n            self.track_train_cfg.sampler, context=self)\n        self.track_share_sampler = False\n\n    def preprocess_gt_masks(self, img_metas, gt_masks, gt_labels, gt_semantic_seg):\n        # gt_masks and gt_semantic_seg are not padded when the batch is formed\n        gt_masks_tensor = []\n        gt_sem_seg = []\n        gt_sem_cls = []\n        # batch_input_shape should be the same across images\n        pad_H, pad_W = img_metas[0]['batch_input_shape']\n        assign_H = pad_H // self.mask_assign_stride\n        assign_W = pad_W // self.mask_assign_stride\n\n        for i, gt_mask in enumerate(gt_masks):\n            mask_tensor = gt_mask.to_tensor(torch.float, 
gt_labels[0].device)\n            if gt_mask.width != pad_W or gt_mask.height != pad_H:\n                pad_wh 
= (0, pad_W - gt_mask.width, 0, pad_H - gt_mask.height)\n                mask_tensor = F.pad(mask_tensor, pad_wh, value=0)\n\n            if gt_semantic_seg is not None:\n                # gt_semantic_seg is zero-padded when the batch is formed,\n                # so mark the padded region as ignore_label\n                gt_semantic_seg[\n                i, :, img_metas[i]['img_shape'][0]:, :] = self.ignore_label\n                gt_semantic_seg[\n                i, :, :, img_metas[i]['img_shape'][1]:] = self.ignore_label\n                if self.cityscapes:\n                    sem_labels, sem_seg = sem2ins_masks_cityscapes(\n                        gt_semantic_seg[i],\n                        ignore_label=self.ignore_label,\n                        label_shift=self.num_thing_classes)\n                elif self.kitti_step:\n                    sem_labels, sem_seg = sem2ins_masks_kitti_step(\n                        gt_semantic_seg[i],\n                        ignore_label=self.ignore_label,\n                        label_shift=2,\n                        thing_label_in_seg=(11, 13))\n                else:\n                    sem_labels, sem_seg = sem2ins_masks(\n                        gt_semantic_seg[i],\n                        ignore_label=self.ignore_label,\n                        label_shift=self.num_thing_classes,\n                        thing_label_in_seg=self.thing_label_in_seg)\n\n                if sem_seg.shape[0] == 0:\n                    gt_sem_seg.append(\n                        mask_tensor.new_zeros(\n                            (mask_tensor.size(0), assign_H, assign_W)))\n                else:\n                    gt_sem_seg.append(\n                        F.interpolate(\n                            sem_seg[None], (assign_H, assign_W),\n                            mode='bilinear',\n                            align_corners=False)[0])\n                gt_sem_cls.append(sem_labels)\n            else:\n                gt_sem_seg = None\n              
  gt_sem_cls = None\n\n            if mask_tensor.shape[0] == 0:\n                gt_masks_tensor.append(\n                    mask_tensor.new_zeros(\n                        (mask_tensor.size(0), assign_H, assign_W)))\n            else:\n                gt_masks_tensor.append(\n                    F.interpolate(\n                        mask_tensor[None], (assign_H, assign_W),  # downsample to 1/4 resolution\n                        mode='bilinear',\n                        align_corners=False)[0])\n\n        return gt_masks_tensor, gt_sem_cls, gt_sem_seg\n\n    def forward_train(self,\n                      img,\n                      img_metas,\n                      gt_bboxes=None,\n                      gt_labels=None,\n                      gt_bboxes_ignore=None,\n                      gt_masks=None,\n                      gt_semantic_seg=None,\n                      gt_instance_ids=None,\n                      ref_img=None,\n                      ref_img_metas=None,\n                      ref_gt_bboxes_ignore=None,\n                      ref_gt_labels=None,\n                      ref_gt_bboxes=None,\n                      ref_gt_masks=None,\n                      ref_gt_semantic_seg=None,\n                      ref_gt_instance_ids=None,\n                      proposals=None,\n                      **kwargs):\n        \"\"\"Forward function of SparseR-CNN-like network in train stage.\n\n        Args:\n            img (Tensor): of shape (N, C, H, W) encoding input images.\n                Typically these should be mean centered and std scaled.\n            img_metas (list[dict]): list of image info dict where each dict\n                has: 'img_shape', 'scale_factor', 'flip', and may also contain\n                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.\n                For details on the values of these keys see\n                :class:`mmdet.datasets.pipelines.Collect`.\n            gt_bboxes (list[Tensor]): Ground truth bboxes for each 
image with\n                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.\n            gt_labels (list[Tensor]): class indices corresponding to each box\n            gt_bboxes_ignore (None | list[Tensor]): specify which bounding\n                boxes can be ignored when computing the loss.\n            gt_masks (list[:obj:`BaseInstanceMasks`]): Instance masks for\n                each image. Required by this detector.\n            gt_semantic_seg (Tensor, optional): Semantic segmentation maps\n                of the key images.\n            gt_instance_ids (list[Tensor]): Instance ids of the ground-truth\n                objects in each key image. Required; used to build the\n                key-to-reference matching targets.\n            proposals (List[Tensor], optional): External proposals are not\n                supported; must be None.\n\n            # This is for video only:\n            ref_img (Tensor): of shape (N, 1, C, H, W) encoding the reference\n                images. Typically these should be mean centered and std scaled.\n                The second dimension indexes the reference images; this\n                detector uses one reference image per key image.\n\n            ref_img_metas (list[list[dict]]): The outer list has one entry per\n                key image; each inner list holds the information dict of its\n                reference image, which has: 'img_shape', 'scale_factor', 'flip',\n                and may also contain 'filename', 'ori_shape', 'pad_shape', and\n                'img_norm_cfg'. For details on the values of these keys see\n                `mmtrack/datasets/pipelines/formatting.py:VideoCollect`.\n\n            ref_gt_bboxes (list[Tensor]): The list only has one Tensor. The\n                Tensor contains ground truth bboxes for each reference image\n                with shape (num_all_ref_gts, 5) in\n                [ref_img_id, tl_x, tl_y, br_x, br_y] format. The ref_img_id\n                starts from 0 and denotes the id of the reference image for\n                each key image.\n\n            ref_gt_labels (list[Tensor]): The list only has one Tensor. 
The\n                Tensor contains class indices corresponding to each reference\n                box with shape (num_all_ref_gts, 2) in\n                [ref_img_id, class_index].\n\n        Returns:\n            dict[str, Tensor]: a dictionary of loss components\n        \"\"\"\n        batch_input_shape = tuple(img[0].size()[-2:])\n        for img_meta in img_metas:\n            img_meta['batch_input_shape'] = batch_input_shape\n\n        assert proposals is None, 'KNet does not support' \\\n                                  ' external proposals'\n        assert gt_masks is not None\n        assert gt_instance_ids is not None\n\n        # preprocess the reference images\n        ref_img = ref_img.squeeze(1)  # (b,3,h,w)\n        img_h, img_w = batch_input_shape\n        ref_masks_gt = []\n        for ref_gt_mask in ref_gt_masks:\n            ref_masks_gt.append(ref_gt_mask[0])\n\n        ref_labels_gt = []\n        for ref_gt_label in ref_gt_labels:\n            ref_labels_gt.append(ref_gt_label[:, 1].long())\n        ref_gt_labels = ref_labels_gt\n\n        ref_semantic_seg_gt = ref_gt_semantic_seg.squeeze(1)\n\n        ref_gt_instance_id_list = []\n        for ref_gt_instance_id in ref_gt_instance_ids:\n            ref_gt_instance_id_list.append(ref_gt_instance_id[:,1].long())\n\n        ref_img_metas_new = []\n        for ref_img_meta in ref_img_metas:\n            ref_img_meta[0]['batch_input_shape'] = batch_input_shape\n            ref_img_metas_new.append(ref_img_meta[0])\n\n        # prepare the gt_match_indices\n        gt_pids_list = []\n        for i in range(len(ref_gt_instance_id_list)):\n            ref_ids = ref_gt_instance_id_list[i].cpu().data.numpy().tolist()\n            gt_ids = gt_instance_ids[i].cpu().data.numpy().tolist()\n            # index of each key-frame instance in the reference frame (-1 if absent)\n            gt_pids = [ref_ids.index(gid) if gid in ref_ids else -1 for gid in gt_ids]\n            gt_pids_list.append(torch.LongTensor([gt_pids]).to(img.device)[0])\n\n        gt_match_indices = gt_pids_list\n\n        # 
gt_masks and gt_semantic_seg are not padded when forming batch\n        gt_masks, gt_sem_cls, gt_sem_seg = self.preprocess_gt_masks(img_metas, gt_masks, gt_labels, gt_semantic_seg)\n\n        ref_gt_masks, ref_gt_sem_cls, ref_gt_sem_seg = self.preprocess_gt_masks(ref_img_metas_new,\n                                                                    ref_masks_gt, ref_gt_labels, ref_semantic_seg_gt)\n\n        x = self.extract_feat(img)\n        x_ref = self.extract_feat(ref_img)\n\n        # current frame\n        rpn_results = self.rpn_head.forward_train(x, img_metas, gt_masks,\n                                                  gt_labels, gt_sem_seg,\n                                                  gt_sem_cls)\n\n        # simple forward to get the reference results\n        # self.rpn_head.eval()\n        # ref_rpn_results = self.rpn_head.simple_test_rpn(x_ref, ref_img_metas_new)\n        # self.rpn_head.train()\n\n        # reference frame\n        ref_rpn_results = self.rpn_head.forward_train(x_ref, ref_img_metas_new, ref_gt_masks,\n                                                  ref_labels_gt, ref_gt_sem_seg,\n                                                  ref_gt_sem_cls)\n\n        (rpn_losses, proposal_feats, x_feats, mask_preds,\n         cls_scores) = rpn_results\n\n        (ref_rpn_losses, ref_proposal_feats, ref_x_feats, ref_mask_preds,\n         ref_cls_scores) = ref_rpn_results\n\n        losses_ref, ref_obj_feats, ref_cls_scores, ref_mask_preds, ref_scaled_mask_preds = self.roi_head.forward_train(\n            ref_x_feats,\n            ref_proposal_feats,\n            ref_mask_preds,\n            ref_cls_scores,\n            ref_img_metas_new,\n            ref_gt_masks,\n            ref_gt_labels,\n            gt_sem_seg=ref_gt_sem_seg,\n            gt_sem_cls=ref_gt_sem_cls,\n            imgs_whwh=None)\n\n\n        if self.link_previous:\n            losses, object_feats, cls_scores, mask_preds, scaled_mask_preds = 
self.roi_head.forward_train_with_previous(\n                x_feats,\n                proposal_feats,\n                mask_preds,\n                cls_scores,\n                img_metas,\n                gt_masks,\n                gt_labels,\n                gt_bboxes_ignore=gt_bboxes_ignore,\n                gt_bboxes=gt_bboxes,\n                gt_sem_seg=gt_sem_seg,\n                gt_sem_cls=gt_sem_cls,\n                imgs_whwh=None,\n                previous_obj_feats=ref_obj_feats,\n                previous_mask_preds=ref_scaled_mask_preds,\n                previous_x_feats=ref_x_feats,\n            )\n        else:\n            # forward to get the current results\n            losses, object_feats, cls_scores, mask_preds, scaled_mask_preds = self.roi_head.forward_train(\n                x_feats,\n                proposal_feats,\n                mask_preds,\n                cls_scores,\n                img_metas,\n                gt_masks,\n                gt_labels,\n                gt_bboxes_ignore=gt_bboxes_ignore,\n                gt_bboxes=gt_bboxes,\n                gt_sem_seg=gt_sem_seg,\n                gt_sem_cls=gt_sem_cls,\n                imgs_whwh=None)\n\n        # ===== Tracking Part -==== #\n        # assign both key frame and reference frame tracking targets\n        key_sampling_results, ref_sampling_results = [], []\n        num_imgs = len(img_metas)\n\n        for i in range(num_imgs):\n            assign_result = self.track_roi_assigner.assign(\n                scaled_mask_preds[i][:self.num_proposals].detach(), cls_scores[i][:self.num_proposals, :self.num_thing_classes].detach(),\n                gt_masks[i], gt_labels[i], img_meta=img_metas[i])\n            sampling_result = self.track_roi_sampler.sample(\n                assign_result,\n                mask_preds[i][:self.num_proposals].detach(),\n                gt_masks[i])\n            key_sampling_results.append(sampling_result)\n\n            ref_assign_result = 
self.track_roi_assigner.assign(\n                ref_scaled_mask_preds[i][:self.num_proposals].detach(), ref_cls_scores[i][:self.num_proposals, :self.num_thing_classes].detach(),\n                ref_gt_masks[i], ref_gt_labels[i], img_meta=ref_img_metas_new[i])\n            ref_sampling_result = self.track_roi_sampler.sample(\n                ref_assign_result,\n                ref_mask_preds[i][:self.num_proposals].detach(),\n                ref_gt_masks[i])\n            ref_sampling_results.append(ref_sampling_result)\n\n        # roi feature embeddings\n        key_masks = [res.pos_gt_masks for res in key_sampling_results]\n        for i in range(len(key_masks)):\n            key_masks[i] = F.interpolate(key_masks[i].unsqueeze(0),\n                                        size=(img_h, img_w), mode=\"bilinear\", align_corners=False).squeeze(0)\n            key_masks[i] = (key_masks[i].sigmoid() > 0.5).float()\n\n        key_feats = self._track_forward(x, key_masks)\n\n        # roi feature embeddings\n        ref_masks = [res.pos_gt_masks for res in ref_sampling_results]\n        for i in range(len(ref_masks)):\n            ref_masks[i] = F.interpolate(ref_masks[i].unsqueeze(0),\n                                        size=(img_h, img_w), mode=\"bilinear\", align_corners=False).squeeze(0)\n            ref_masks[i] = (ref_masks[i].sigmoid() > 0.5).float()\n\n        ref_feats = self._track_forward(x_ref, ref_masks)\n\n        match_feats = self.track_head.match(key_feats, ref_feats,\n                                            key_sampling_results,\n                                            ref_sampling_results)\n\n        asso_targets = self.track_head.get_track_targets(\n            gt_match_indices, key_sampling_results, ref_sampling_results)\n        loss_track = self.track_head.loss(*match_feats, *asso_targets)\n\n        losses_ref = self.add_ref_loss(losses_ref)\n        ref_rpn_losses = self.add_ref_rpn_loss(ref_rpn_losses)\n\n        
losses.update(ref_rpn_losses)\n        losses.update(rpn_losses)\n        losses.update(losses_ref)\n        losses.update(loss_track)\n\n        return losses\n\n    def simple_test(self, img, img_metas, rescale=False, ref_img=None, **kwargs):\n        \"\"\"Test function without test time augmentation.\n\n        Args:\n            img (torch.Tensor): Input images of shape (N, C, H, W). Only the\n                first image's results are used for tracking.\n            img_metas (list[dict]): List of image information.\n            rescale (bool): Whether to rescale the results.\n                Defaults to False.\n            ref_img (torch.Tensor, optional): Not used at test time.\n            **kwargs: Must contain 'img_id' unless the filename contains 'city'.\n\n        Returns:\n            tuple: (semantic_map, track_maps, None, vis_sem, vis_tracker),\n                i.e. the per-pixel semantic label map, the per-pixel track-id\n                map (0 where no track id is assigned), a placeholder, and RGB\n                visualizations of the semantic and track-id maps.\n        \"\"\"\n\n        # Detect the first frame of a clip from the frame id (image id modulo 10000);\n        # on the first frame, reset the tracker and the linking memory.\n        # assert 'city' in img_metas[0]['filename'] and 'iid' in img_metas[0]\n        if \"city\" in img_metas[0]['filename']:\n            iid = img_metas[0]['iid']\n            fid = iid % 10000\n            is_first = (fid == 1)\n        else:\n            iid = kwargs['img_id'][0].item()\n            fid = iid % 10000\n            is_first = (fid == 0)\n\n        if is_first:\n            self.init_tracker()\n            self.obj_feats_memory = None\n            self.x_feats_memory = None\n            self.mask_preds_memory = None\n\n        # for current frame\n        x = self.extract_feat(img)\n        # current frame inference\n        rpn_results = self.rpn_head.simple_test_rpn(x, img_metas)\n        (proposal_feats, x_feats, mask_preds, cls_scores,\n         seg_preds) = 
rpn_results\n\n        if self.link_previous:\n            cur_segm_results, query_output, cls_scores, mask_preds, scaled_mask_preds = self.roi_head.simple_test_with_previous(\n                x_feats,\n                proposal_feats,\n                mask_preds,\n                cls_scores,\n                img_metas,\n                
previous_obj_feats=self.obj_feats_memory,\n                previous_mask_preds=self.mask_preds_memory,\n                previous_x_feats=self.x_feats_memory,\n            )\n            self.obj_feats_memory = query_output\n            self.x_feats_memory = x_feats\n            self.mask_preds_memory = scaled_mask_preds\n        else:\n            cur_segm_results, query_output, cls_scores, mask_preds, scaled_mask_preds = self.roi_head.simple_test(\n                x_feats,\n                proposal_feats,\n                mask_preds,\n                cls_scores,\n                img_metas)\n\n        # for tracking part\n        _, segm_result, mask_preds, panoptic_result = cur_segm_results[0]\n        panoptic_seg, segments_info = panoptic_result\n\n        if self.semantic_filter:\n            seg_preds = torch.nn.functional.interpolate(seg_preds, panoptic_seg.shape, mode='bilinear', align_corners=False)\n            seg_preds = seg_preds.sigmoid()\n            seg_out = seg_preds.argmax(1)\n            semantic_thing = (seg_out < self.num_thing_classes).to(dtype=torch.float32)\n        else:\n            semantic_thing = 1.\n\n        # get sorted tracking thing ids, labels, masks, score for tracking\n        things_index_for_tracking, things_labels_for_tracking, thing_masks_for_tracking, things_score_for_tracking = \\\n            self.get_things_id_for_tracking(panoptic_seg, segments_info)\n        things_labels_for_tracking = torch.Tensor(things_labels_for_tracking).to(cls_scores.device).long()\n        if len(things_labels_for_tracking) > 0:\n            things_bbox_for_tracking = torch.zeros((len(things_score_for_tracking), 5),\n                                                   dtype=torch.float, device=x_feats.device)\n            things_bbox_for_tracking[:, 4] = torch.tensor(things_score_for_tracking,\n                                                          device=things_bbox_for_tracking.device)\n\n            thing_masks_for_tracking_final = []\n    
        for mask in thing_masks_for_tracking:\n                thing_masks_for_tracking_final.append(torch.Tensor(mask).unsqueeze(0).to(\n                    x_feats.device).float())\n            thing_masks_for_tracking_final = torch.cat(thing_masks_for_tracking_final, 0)\n            thing_masks_for_tracking = thing_masks_for_tracking_final\n            thing_masks_for_tracking_with_semantic_filter = thing_masks_for_tracking_final * semantic_thing\n\n        if len(things_labels_for_tracking) == 0:\n            track_feats = None\n        else:\n            # tracking embedding features\n            track_feats = self._track_forward(x, thing_masks_for_tracking_with_semantic_filter)\n\n        if track_feats is not None:\n            # assert len(things_id_for_tracking) == len(things_labels_for_tracking)\n            things_bbox_for_tracking[:, :4] = torch.tensor(tensor_mask2box(thing_masks_for_tracking_with_semantic_filter),\n                                                           device=things_bbox_for_tracking.device)\n            bboxes, labels, ids = self.tracker.match(\n                bboxes=things_bbox_for_tracking,\n                labels=things_labels_for_tracking,\n                track_feats=track_feats,\n                frame_id=fid)\n            # shift ids by one so that unmatched detections (id -1) become 0, i.e. 
no track\n            ids = ids + 1\n        else:\n            ids = []\n\n        track_maps = self.generate_track_id_maps(ids, thing_masks_for_tracking, panoptic_seg)\n\n        semantic_map = self.get_semantic_seg(panoptic_seg, segments_info)\n\n        from scripts.visualizer import trackmap2rgb, cityscapes_cat2rgb, draw_bbox_on_img\n        vis_tracker = trackmap2rgb(track_maps)\n        vis_sem = cityscapes_cat2rgb(semantic_map)\n        if len(things_labels_for_tracking):\n            vis_tracker = draw_bbox_on_img(vis_tracker, things_bbox_for_tracking.cpu().numpy())\n\n        # Visualization end\n        return semantic_map, track_maps, None, vis_sem, vis_tracker\n\n\n    def _track_forward(self, x, mask_pred):\n        \"\"\"Track head forward function used in both training and testing.\n        The predicted masks are converted to boxes, and RoI features pooled\n        from these boxes are passed to the track head.\"\"\"\n        if not self.training:\n            mask_pred = [mask_pred]\n        bbox_list = batch_mask2boxlist(mask_pred)\n        track_rois = bboxlist2roi(bbox_list)\n        track_rois = track_rois.clamp(min=0.0)\n        track_feats = self.track_roi_extractor(x[:self.track_roi_extractor.num_inputs], track_rois)\n        track_feats = self.track_head(track_feats)\n\n        return track_feats\n\n    def forward_dummy(self, img):\n        \"\"\"Used for computing network flops.\n\n        See `mmdetection/tools/get_flops.py`\n        \"\"\"\n        # backbone\n        x = self.extract_feat(img)\n        # rpn\n        num_imgs = len(img)\n        dummy_img_metas = [\n            dict(img_shape=(800, 1333, 3)) for _ in range(num_imgs)\n        ]\n        rpn_results = self.rpn_head.simple_test_rpn(x, dummy_img_metas)\n        (proposal_feats, x_feats, mask_preds, cls_scores,\n         seg_preds) = rpn_results\n        roi_outs = self.roi_head.forward_dummy(x_feats, proposal_feats,\n                                               dummy_img_metas)\n        return roi_outs\n\n    def extract_feat(self, 
img):\n        \"\"\"Directly extract features from the backbone+neck.\"\"\"\n        x = self.backbone(img)\n        if self.with_neck:\n            x = self.neck(x)\n        return x\n\n    @property\n    def with_rpn(self):\n        \"\"\"bool: whether the detector has RPN\"\"\"\n        return hasattr(self, 'rpn_head') and self.rpn_head is not None\n\n    @property\n    def with_roi_head(self):\n        \"\"\"bool: whether the detector has a RoI head\"\"\"\n        return hasattr(self, 'roi_head') and self.roi_head is not None\n\n    def aug_test(self, x, proposal_list, img_metas, rescale=False, **kwargs):\n        \"\"\"Test with augmentations.\n\n        If rescale is False, then returned bboxes and masks will fit the scale\n        of imgs[0].\n        \"\"\"\n        pass\n\n    def get_things_id_for_tracking(self, panoptic_seg, seg_infos):\n        idxs = []\n        labels = []\n        masks = []\n        score = []\n        for segment in seg_infos:\n            if segment['isthing']:\n                thing_mask = panoptic_seg == segment[\"id\"]\n                masks.append(thing_mask)\n                idxs.append(segment[\"instance_id\"])\n                labels.append(segment['category_id'])\n                score.append(segment['score'])\n        return idxs, labels, masks, score\n\n\n    def pack_things_object(self, object_feats, ref_object_feats):\n        object_feats, ref_object_feats = object_feats.squeeze(-1).squeeze(-1), ref_object_feats.squeeze(-1).squeeze(-1)\n        thing_object_feats = torch.split(object_feats, [self.roi_head.num_proposals, self.num_stuff_classes], dim=1)[0]\n        ref_thing_object_feats = torch.split(ref_object_feats, [self.roi_head.num_proposals, self.num_stuff_classes], dim=1)[0]\n        return thing_object_feats, ref_thing_object_feats\n\n    def pack_things_masks(self, mask_pred, ref_mask_pred):\n        thing_mask_pred = torch.split(mask_pred, [self.roi_head.num_proposals, self.num_stuff_classes], dim=1)[0]\n        ref_thing_mask_pred = torch.split(ref_mask_pred, 
[self.roi_head.num_proposals, self.num_stuff_classes], dim=1)[0]\n        return thing_mask_pred, ref_thing_mask_pred\n\n    def get_semantic_seg(self, panoptic_seg, segments_info):\n        # Cityscapes train ids of the KITTI-STEP thing classes (11: person, 13: car)\n        kitti_step2cityscapes = [11, 13]\n        semantic_seg = np.zeros(panoptic_seg.shape)\n        for segment in segments_info:\n            if segment['isthing']:\n                if self.kitti_step:\n                    cat_cur = kitti_step2cityscapes[segment[\"category_id\"]]\n                    semantic_seg[panoptic_seg == segment[\"id\"]] = cat_cur\n                else:\n                    semantic_seg[panoptic_seg == segment[\"id\"]] = segment[\"category_id\"] + 11\n            else:\n                # for stuff (0 to n-1)\n                if self.kitti_step:\n                    cat_cur = segment[\"category_id\"]\n                    cat_cur -= 1\n                    # skip the thing train ids (11, 13) to recover the Cityscapes train id\n                    offset = 0\n                    for thing_id in kitti_step2cityscapes:\n                        if cat_cur + offset >= thing_id:\n                            offset += 1\n                    cat_cur += offset\n                    semantic_seg[panoptic_seg == segment[\"id\"]] = cat_cur\n                else:\n                    semantic_seg[panoptic_seg == segment[\"id\"]] = segment[\"category_id\"] - 1\n        return semantic_seg\n\n    def generate_track_id_maps(self, ids, masks, panoptic_seg_maps):\n        final_id_maps = np.zeros(panoptic_seg_maps.shape)\n        if len(ids) == 0:\n            return final_id_maps\n        # assert len(things_mask_results) == len(track_results)\n        masks = masks.bool()\n        for i, track_id in enumerate(ids):\n            mask = masks[i].cpu().numpy()\n            final_id_maps[mask] = track_id\n        return final_id_maps\n\n    def add_ref_loss(self, loss_dict):\n        ref_loss = {}\n        for k, v in loss_dict.items():\n            ref_loss[str(k) + \"_ref\"] = v\n        
return ref_loss\n\n    def add_ref_rpn_loss(self, loss_dict):\n        ref_rpn_loss = {}\n        for k, v in loss_dict.items():\n            ref_rpn_loss[str(k) + \"_ref_rpn\"] = v\n        return ref_rpn_loss"
  },
  {
    "path": "knet/video/knet_track_head.py",
    "content": "import warnings\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom mmdet.models.builder import DETECTORS\nfrom mmdet.models.detectors import BaseDetector\nfrom mmdet.models.builder import build_head, build_neck, build_backbone\nfrom knet.det.utils import sem2ins_masks, sem2ins_masks_cityscapes\n\n\n@DETECTORS.register_module()\nclass VideoKNetFuseTrack(BaseDetector):\n    \"\"\"\n        Simple Extension of KNet to Video KNet by the implementation of VPSFuse Net.\n    \"\"\"\n    def __init__(self,\n                 backbone,\n                 neck=None,\n                 rpn_head=None,\n                 roi_head=None,\n                 track_head=None,\n                 extra_neck=None,\n                 train_cfg=None,\n                 test_cfg=None,\n                 pretrained=None,\n                 init_cfg=None,\n                 num_thing_classes=80,\n                 num_stuff_classes=53,\n                 mask_assign_stride=4,\n                 ignore_label=255,\n                 thing_label_in_seg=0,\n                 cityscapes=False,\n                 **kwargs):\n        super(VideoKNetFuseTrack, self).__init__(init_cfg)\n\n        if pretrained:\n            warnings.warn('DeprecationWarning: pretrained is deprecated, '\n                          'please use \"init_cfg\" instead')\n            backbone.pretrained = pretrained\n        self.backbone = build_backbone(backbone)\n\n        if neck is not None:\n            self.neck = build_neck(neck)\n\n        if extra_neck is not None:\n            self.extra_neck = build_neck(extra_neck)\n\n        if rpn_head is not None:\n            rpn_train_cfg = train_cfg.rpn if train_cfg is not None else None\n            rpn_head_ = rpn_head.copy()\n            rpn_head_.update(train_cfg=rpn_train_cfg, test_cfg=test_cfg.rpn)\n            self.rpn_head = build_head(rpn_head_)\n\n        if roi_head is not None:\n            # update train and test cfg here for now\n        
    # TODO: refactor assigner & sampler\n            rcnn_train_cfg = train_cfg.rcnn if train_cfg is not None else None\n            roi_head.update(train_cfg=rcnn_train_cfg)\n            roi_head.update(test_cfg=test_cfg.rcnn)\n            roi_head.pretrained = pretrained\n            self.roi_head = build_head(roi_head)\n\n        if track_head is not None:\n            self.track_head = build_head(track_head)\n\n        self.train_cfg = train_cfg\n        self.test_cfg = test_cfg\n\n        self.num_thing_classes = num_thing_classes\n        self.num_stuff_classes = num_stuff_classes\n        self.mask_assign_stride = mask_assign_stride\n        self.thing_label_in_seg = thing_label_in_seg\n        self.ignore_label = ignore_label\n        self.cityscapes = cityscapes  # whether to train on Cityscapes panoptic segmentation\n\n    def preprocess_gt_masks(self, img_metas, gt_masks, gt_labels, gt_semantic_seg):\n        # gt_masks and gt_semantic_seg are not padded when forming a batch\n        gt_masks_tensor = []\n        gt_sem_seg = []\n        gt_sem_cls = []\n        # batch_input_shape should be the same across images\n        pad_H, pad_W = img_metas[0]['batch_input_shape']\n        assign_H = pad_H // self.mask_assign_stride\n        assign_W = pad_W // self.mask_assign_stride\n\n        for i, gt_mask in enumerate(gt_masks):\n            mask_tensor = gt_mask.to_tensor(torch.float, gt_labels[0].device)\n            if gt_mask.width != pad_W or gt_mask.height != pad_H:\n                pad_wh = (0, pad_W - gt_mask.width, 0, pad_H - gt_mask.height)\n                mask_tensor = F.pad(mask_tensor, pad_wh, value=0)\n\n            if gt_semantic_seg is not None:\n                # gt_semantic_seg is zero-padded when forming a batch,\n                # so set the padded region to ignore_label\n                gt_semantic_seg[\n                i, :, img_metas[i]['img_shape'][0]:, :] = self.ignore_label\n                gt_semantic_seg[\n                i, :, :, 
img_metas[i]['img_shape'][1]:] = self.ignore_label\n                if self.cityscapes:\n                    sem_labels, sem_seg = sem2ins_masks_cityscapes(\n                        gt_semantic_seg[i],\n                        ignore_label=self.ignore_label,\n                        label_shift=self.num_thing_classes)\n                else:\n                    sem_labels, sem_seg = sem2ins_masks(\n                        gt_semantic_seg[i],\n                        ignore_label=self.ignore_label,\n                        label_shift=self.num_thing_classes,\n                        thing_label_in_seg=self.thing_label_in_seg)\n\n                if sem_seg.shape[0] == 0:\n                    gt_sem_seg.append(\n                        mask_tensor.new_zeros(\n                            (mask_tensor.size(0), assign_H, assign_W)))\n                else:\n                    gt_sem_seg.append(\n                        F.interpolate(\n                            sem_seg[None], (assign_H, assign_W),\n                            mode='bilinear',\n                            align_corners=False)[0])\n                gt_sem_cls.append(sem_labels)\n            else:\n                gt_sem_seg = None\n                gt_sem_cls = None\n\n            if mask_tensor.shape[0] == 0:\n                gt_masks_tensor.append(\n                    mask_tensor.new_zeros(\n                        (mask_tensor.size(0), assign_H, assign_W)))\n            else:\n                gt_masks_tensor.append(\n                    F.interpolate(\n                        mask_tensor[None], (assign_H, assign_W),  # downsample to 1/4 resolution\n                        mode='bilinear',\n                        align_corners=False)[0])\n\n        return gt_masks_tensor, gt_sem_cls, gt_sem_seg\n\n    def forward_train(self,\n                      img,\n                      img_metas,\n                      gt_bboxes=None,\n                      gt_labels=None,\n                      
gt_bboxes_ignore=None,\n                      gt_masks=None,\n                      gt_semantic_seg=None,\n                      gt_instance_ids=None,\n                      ref_img=None,\n                      ref_img_metas=None,\n                      ref_gt_bboxes_ignore=None,\n                      ref_gt_labels=None,\n                      ref_gt_bboxes=None,\n                      ref_gt_masks=None,\n                      ref_gt_semantic_seg=None,\n                      ref_gt_instance_ids=None,\n                      proposals=None,\n                      **kwargs):\n        \"\"\"Forward function of SparseR-CNN-like network in train stage.\n\n        Args:\n            img (Tensor): of shape (N, C, H, W) encoding input images.\n                Typically these should be mean centered and std scaled.\n            img_metas (list[dict]): list of image info dict where each dict\n                has: 'img_shape', 'scale_factor', 'flip', and may also contain\n                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.\n                For details on the values of these keys see\n                :class:`mmdet.datasets.pipelines.Collect`.\n            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with\n                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.\n            gt_labels (list[Tensor]): class indices corresponding to each box\n            gt_bboxes_ignore (None | list[Tensor]): specify which bounding\n                boxes can be ignored when computing the loss.\n            gt_masks (list[BitmapMasks]): Instance segmentation masks for\n                each image. Required by this architecture.\n            proposals (None): external proposals are not supported by KNet\n                and must be None.\n\n            # This is for video only:\n            ref_img (Tensor): of shape (N, 1, C, H, W) encoding the reference\n                image of each input image. 
Typically these should be mean\n                centered and std scaled.\n\n            ref_img_metas (list[list[dict]]): The first list only has one\n                element. The second list contains reference image information\n                dict where each dict has: 'img_shape', 'scale_factor', 'flip',\n                and may also contain 'filename', 'ori_shape', 'pad_shape', and\n                'img_norm_cfg'. For details on the values of these keys see\n                `mmtrack/datasets/pipelines/formatting.py:VideoCollect`.\n\n            ref_gt_bboxes (list[Tensor]): The list only has one Tensor. The\n                Tensor contains ground truth bboxes for each reference image\n                with shape (num_all_ref_gts, 5) in\n                [ref_img_id, tl_x, tl_y, br_x, br_y] format. The ref_img_id\n                starts from 0 and denotes the id of the reference image for\n                each key image.\n\n            ref_gt_labels (list[Tensor]): The list only has one Tensor. The\n                Tensor contains class indices corresponding to each reference\n                box with shape (num_all_ref_gts, 2) in\n                [ref_img_id, class_index].\n\n        Returns:\n            dict[str, Tensor]: a dictionary of loss components\n        \"\"\"\n        batch_input_shape = tuple(img[0].size()[-2:])\n        for img_meta in img_metas:\n            img_meta['batch_input_shape'] = batch_input_shape\n\n        assert proposals is None, 'KNet does not support' \\\n                                  ' external proposals'\n        assert gt_masks is not None\n        assert gt_instance_ids is not None\n\n        # preprocess the reference images\n        ref_img = ref_img.squeeze(1)  # (b,3,h,w)\n        ref_masks_gt = []\n        for ref_gt_mask in ref_gt_masks:\n            ref_masks_gt.append(ref_gt_mask[0])\n\n        ref_labels_gt = []\n        for ref_gt_label in ref_gt_labels:\n            ref_labels_gt.append(ref_gt_label[:, 1].long())\n        ref_gt_labels = ref_labels_gt\n\n        ref_semantic_seg_gt = 
ref_gt_semantic_seg.squeeze(1)\n\n        ref_gt_instance_id_list = []\n        for ref_gt_instance_id in ref_gt_instance_ids:\n            ref_gt_instance_id_list.append(ref_gt_instance_id[:, 1].long())\n\n        ref_img_metas_new = []\n        for ref_img_meta in ref_img_metas:\n            ref_img_meta[0]['batch_input_shape'] = batch_input_shape\n            ref_img_metas_new.append(ref_img_meta[0])\n\n        gt_pids_list = []\n        for i in range(len(ref_gt_instance_id_list)):\n            ref_ids = ref_gt_instance_id_list[i].cpu().data.numpy().tolist()\n            gt_ids = gt_instance_ids[i].cpu().data.numpy().tolist()\n            # 1-based index of the matching instance in the reference frame, 0 if absent\n            gt_pids = [ref_ids.index(iid) + 1 if iid in ref_ids else 0 for iid in gt_ids]\n            gt_pids_list.append(torch.LongTensor([gt_pids]).to(img.device)[0])\n        gt_pids = gt_pids_list\n\n        # gt_masks and gt_semantic_seg are not padded when forming batch\n        gt_masks, 
gt_sem_cls, gt_sem_seg = self.preprocess_gt_masks(img_metas, gt_masks, gt_labels, gt_semantic_seg)\n\n        ref_gt_masks, ref_gt_sem_cls, ref_gt_sem_seg = self.preprocess_gt_masks(ref_img_metas_new,\n                                                                    ref_masks_gt, ref_labels_gt, ref_semantic_seg_gt)\n\n        x = self.extract_feat(img)\n        x_ref = self.extract_feat(ref_img)\n        rpn_results = self.rpn_head.forward_train(x, img_metas, gt_masks,\n                                                  gt_labels, gt_sem_seg,\n                                                  gt_sem_cls)\n\n        ref_rpn_results = self.rpn_head.forward_train(x_ref, ref_img_metas_new, ref_gt_masks,\n                                                  ref_labels_gt, ref_gt_sem_seg,\n                                                  ref_gt_sem_cls)\n\n        (rpn_losses, proposal_feats, x_feats, mask_preds,\n         cls_scores) = rpn_results\n\n        (ref_rpn_losses, ref_proposal_feats, ref_x_feats, ref_mask_preds,\n         ref_cls_scores) = ref_rpn_results\n\n        losses, sample_results, object_feats = self.roi_head.forward_train(\n            x_feats,\n            proposal_feats,\n            mask_preds,\n            cls_scores,\n            img_metas,\n            gt_masks,\n            gt_labels,\n            gt_pids=gt_pids,\n            gt_bboxes_ignore=gt_bboxes_ignore,\n            gt_bboxes=gt_bboxes,\n            gt_sem_seg=gt_sem_seg,\n            gt_sem_cls=gt_sem_cls,\n            imgs_whwh=None)\n\n        ref_losses, ref_sample_results, ref_object_feats = self.roi_head.forward_train(\n            ref_x_feats,\n            ref_proposal_feats,\n            ref_mask_preds,\n            ref_cls_scores,\n            ref_img_metas_new,\n            ref_gt_masks,\n            ref_gt_labels,\n            gt_bboxes=ref_gt_bboxes,\n            gt_bboxes_ignore=ref_gt_bboxes_ignore,\n            gt_sem_seg=ref_gt_sem_seg,\n            
gt_sem_cls=ref_gt_sem_cls,\n            imgs_whwh=None)\n        proposals_nums = [self.roi_head.num_proposals] * img.size()[0]\n        ref_proposals_nums = proposals_nums\n\n        object_feats, ref_object_feats = self.pack_things_object(object_feats, ref_object_feats)\n        match_score = self.track_head(object_feats, ref_object_feats, proposals_nums, ref_proposals_nums)\n        track_loss = self.track_head.loss(match_score, sample_results)\n\n        # format the loss\n        ref_rpn_losses = self.add_ref_rpn_loss(ref_rpn_losses)\n        ref_losses = self.add_ref_rpn_loss(ref_losses)\n\n        losses.update(ref_rpn_losses)\n        losses.update(ref_losses)\n        losses.update(track_loss)\n        losses.update(rpn_losses)\n\n        return losses\n\n    def simple_test(self, img, img_metas, rescale=False, ref_img=None):\n        \"\"\"Test function without test time augmentation.\n\n        Args:\n            img (torch.Tensor): Input image of the current frame.\n            img_metas (list[dict]): List of image information.\n            rescale (bool): Whether to rescale the results.\n                Defaults to False.\n            ref_img (list[torch.Tensor], optional): Reference frame; only\n                ref_img[0] is used. 
Defaults to None.\n\n        Returns:\n            tuple: (cur_segm_results, track_maps, sseg_results), i.e. the\n                segmentation results of the current frame, the per-pixel\n                track-id map (np.ndarray) and the semantic map (np.ndarray).\n        \"\"\"\n\n        if ref_img is not None:\n            ref_img = ref_img[0]\n        # check whether this is the first frame of the clip\n        assert 'city' in img_metas[0]['filename'] and 'iid' in img_metas[0]\n        iid = img_metas[0]['iid']\n        fid = iid % 10000\n        is_first = (fid == 1)\n\n        # for current frame\n        x = self.extract_feat(img)\n        rpn_results = self.rpn_head.simple_test_rpn(x, img_metas)\n        (proposal_feats, x_feats, mask_preds, cls_scores,\n         seg_preds) = rpn_results\n\n        if not is_first:\n            ref_x = self.extract_feat(ref_img)\n            ref_rpn_results = self.rpn_head.simple_test_rpn(ref_x, img_metas)\n            (ref_proposal_feats, ref_x_feats, ref_mask_preds, ref_cls_scores,\n             ref_seg_preds) = ref_rpn_results\n            x_fuse = self.combine(ref_x_feats + x_feats)\n\n        cur_segm_results, cur_object_query = self.roi_head.simple_test(\n            x_feats,\n            proposal_feats,\n            mask_preds,\n            cls_scores,\n            img_metas,\n            imgs_whwh=None,\n            rescale=rescale)\n\n        bbox_result, segm_result, panoptic_result = cur_segm_results[0]\n\n        panoptic_seg, segments_info = panoptic_result\n\n        cur_results, sseg_results = self.pack_stuff_things_result(panoptic_seg, segments_info)\n\n        if is_first:\n            self.track_query = cur_object_query\n\n        if not is_first:\n            track_seg_results = self.track_roi_head.simple_test(\n                    x_fuse,\n                    self.track_query,\n                    ref_mask_preds,\n                    ref_cls_scores,\n                    img_metas,\n                    imgs_whwh=None,\n                    rescale=rescale\n            )\n  
          bbox_result, segm_result, panoptic_result = track_seg_results[0]\n            track_panoptic_seg, track_segments_info = 
panoptic_result\n            track_results, ref_sseg_results = self.pack_stuff_things_result(track_panoptic_seg, track_segments_info)\n\n            # update the tracking query\n            self.track_query = cur_object_query\n\n        if is_first:\n            self.tracker.reset_all()\n            init_track_results = self.tracker.init_track(cur_results)\n            track_maps = self.generate_track_id_maps(init_track_results, panoptic_seg)\n\n        else:\n            results = self.tracker.step(cur_results, track_results)\n            track_maps = self.generate_track_id_maps(results, panoptic_seg)\n\n        return cur_segm_results, track_maps, sseg_results\n\n    def forward_dummy(self, img):\n        \"\"\"Used for computing network flops.\n\n        See `mmdetection/tools/get_flops.py`\n        \"\"\"\n        # backbone\n        x = self.extract_feat(img)\n        # rpn\n        num_imgs = len(img)\n        dummy_img_metas = [\n            dict(img_shape=(800, 1333, 3)) for _ in range(num_imgs)\n        ]\n        rpn_results = self.rpn_head.simple_test_rpn(x, dummy_img_metas)\n        (proposal_feats, x_feats, mask_preds, cls_scores,\n         seg_preds) = rpn_results\n        roi_outs = self.roi_head.forward_dummy(x_feats, proposal_feats,\n                                               dummy_img_metas)\n        return roi_outs\n\n    def extract_feat(self, img):\n        \"\"\"Directly extract features from the backbone+neck.\"\"\"\n        x = self.backbone(img)\n        if self.with_neck:\n            x = self.neck(x)\n        return x\n\n    @property\n    def with_rpn(self):\n        \"\"\"bool: whether the detector has RPN\"\"\"\n        return hasattr(self, 'rpn_head') and self.rpn_head is not None\n\n    @property\n    def with_roi_head(self):\n        \"\"\"bool: whether the detector has a RoI head\"\"\"\n        return hasattr(self, 'roi_head') and self.roi_head is not None\n\n    def aug_test(self, x, proposal_list, img_metas, 
rescale=False, **kwargs):\n        \"\"\"Test with augmentations.\n\n        If rescale is False, then returned bboxes and masks will fit the scale\n        of imgs[0].\n        \"\"\"\n        pass\n\n    def pack_things_object(self, object_feats, ref_object_feats):\n        object_feats, ref_object_feats = object_feats.squeeze(-1).squeeze(-1), ref_object_feats.squeeze(-1).squeeze(-1)\n        thing_object_feats = torch.split(object_feats, [self.roi_head.num_proposals, self.num_stuff_classes], dim=1)[0]\n        ref_thing_object_feats = torch.split(ref_object_feats, [self.roi_head.num_proposals, self.num_stuff_classes], dim=1)[0]\n        return thing_object_feats, ref_thing_object_feats\n\n    def add_track_loss(self, loss_dict):\n        track_loss = {}\n        for k, v in loss_dict.items():\n            track_loss[str(k) + \"_track\"] = v\n        return track_loss\n\n    def add_ref_rpn_loss(self, loss_dict):\n        ref_rpn_loss = {}\n        for k, v in loss_dict.items():\n            ref_rpn_loss[str(k) + \"_ref\"] = v\n        return ref_rpn_loss\n\n    def pack_stuff_things_result(self, panoptic_seg, segments_info):\n        results = {}\n        masks = []\n        scores = []\n        semantic_seg = np.zeros(panoptic_seg.shape)\n        for segment in segments_info:\n            if segment['isthing']:\n                thing_mask = panoptic_seg == segment[\"id\"]\n                masks.append(thing_mask)\n                scores.append(segment[\"score\"])\n                # thing labels are shifted by 11 to follow the 11 Cityscapes stuff classes\n                semantic_seg[panoptic_seg == segment[\"id\"]] = segment[\"category_id\"] + 11\n            else:\n                # for stuff (0 to n-1)\n                semantic_seg[panoptic_seg == segment[\"id\"]] = segment[\"category_id\"] - 1\n\n        results[\"masks\"] = np.array(masks)  # (N, H, W)\n        results[\"scores\"] = np.array(scores)  # (N,)\n\n        return results, semantic_seg\n\n    def 
generate_track_id_maps(self, track_results, panoptic_seg_maps):\n        final_id_maps = np.zeros(panoptic_seg_maps.shape)\n        # assert len(things_mask_results) == len(track_results)\n        for track in track_results:\n            track_id = track[\"tracking_id\"]\n            mask = track[\"mask\"]\n            final_id_maps[mask] = track_id\n        return final_id_maps"
  },
  {
    "path": "knet/video/knet_track_head_roi_align.py",
    "content": "import warnings\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom mmdet.models.builder import DETECTORS\nfrom mmdet.models.detectors import BaseDetector\nfrom mmdet.models.builder import build_head, build_neck, build_backbone\nfrom knet.det.utils import sem2ins_masks, sem2ins_masks_cityscapes\n\n\n@DETECTORS.register_module()\nclass VideoKNetFuseROITrack(BaseDetector):\n    \"\"\"\n        Simple Extension of KNet to Video KNet by the implementation of VPSFuse Net.\n    \"\"\"\n    def __init__(self,\n                 backbone,\n                 neck=None,\n                 rpn_head=None,\n                 roi_head=None,\n                 track_head=None,\n                 extra_neck=None,\n                 train_cfg=None,\n                 test_cfg=None,\n                 pretrained=None,\n                 init_cfg=None,\n                 num_thing_classes=80,\n                 num_stuff_classes=53,\n                 mask_assign_stride=4,\n                 ignore_label=255,\n                 thing_label_in_seg=0,\n                 cityscapes=False,\n                 **kwargs):\n        super(VideoKNetFuseROITrack, self).__init__(init_cfg)\n\n        if pretrained:\n            warnings.warn('DeprecationWarning: pretrained is deprecated, '\n                          'please use \"init_cfg\" instead')\n            backbone.pretrained = pretrained\n        self.backbone = build_backbone(backbone)\n\n        if neck is not None:\n            self.neck = build_neck(neck)\n\n        if extra_neck is not None:\n            self.extra_neck = build_neck(extra_neck)\n\n        if rpn_head is not None:\n            rpn_train_cfg = train_cfg.rpn if train_cfg is not None else None\n            rpn_head_ = rpn_head.copy()\n            rpn_head_.update(train_cfg=rpn_train_cfg, test_cfg=test_cfg.rpn)\n            self.rpn_head = build_head(rpn_head_)\n\n        if roi_head is not None:\n            # update train and test cfg here for now\n  
          # TODO: refactor assigner & sampler\n            rcnn_train_cfg = train_cfg.rcnn if train_cfg is not None else None\n            roi_head.update(train_cfg=rcnn_train_cfg)\n            roi_head.update(test_cfg=test_cfg.rcnn)\n            roi_head.pretrained = pretrained\n            self.roi_head = build_head(roi_head)\n\n        if track_head is not None:\n            self.track_head = build_head(track_head)\n\n        self.train_cfg = train_cfg\n        self.test_cfg = test_cfg\n\n        self.num_thing_classes = num_thing_classes\n        self.num_stuff_classes = num_stuff_classes\n        self.mask_assign_stride = mask_assign_stride\n        self.thing_label_in_seg = thing_label_in_seg\n        self.ignore_label = ignore_label\n        self.cityscapes = cityscapes  # whether to train on Cityscapes panoptic segmentation\n\n    def preprocess_gt_masks(self, img_metas, gt_masks, gt_labels, gt_semantic_seg):\n        # gt_masks and gt_semantic_seg are not padded when forming a batch\n        gt_masks_tensor = []\n        gt_sem_seg = []\n        gt_sem_cls = []\n        # batch_input_shape should be the same across images\n        pad_H, pad_W = img_metas[0]['batch_input_shape']\n        assign_H = pad_H // self.mask_assign_stride\n        assign_W = pad_W // self.mask_assign_stride\n\n        for i, gt_mask in enumerate(gt_masks):\n            mask_tensor = gt_mask.to_tensor(torch.float, gt_labels[0].device)\n            if gt_mask.width != pad_W or gt_mask.height != pad_H:\n                pad_wh = (0, pad_W - gt_mask.width, 0, pad_H - gt_mask.height)\n                mask_tensor = F.pad(mask_tensor, pad_wh, value=0)\n\n            if gt_semantic_seg is not None:\n                # gt_semantic_seg is zero-padded when forming a batch,\n                # so set the padded region to ignore_label\n                gt_semantic_seg[\n                i, :, img_metas[i]['img_shape'][0]:, :] = self.ignore_label\n                gt_semantic_seg[\n                i, :, 
:, img_metas[i]['img_shape'][1]:] = self.ignore_label\n                if self.cityscapes:\n                    sem_labels, sem_seg = sem2ins_masks_cityscapes(\n                        gt_semantic_seg[i],\n                        ignore_label=self.ignore_label,\n                        label_shift=self.num_thing_classes)\n                else:\n                    sem_labels, sem_seg = sem2ins_masks(\n                        gt_semantic_seg[i],\n                        ignore_label=self.ignore_label,\n                        label_shift=self.num_thing_classes,\n                        thing_label_in_seg=self.thing_label_in_seg)\n\n                if sem_seg.shape[0] == 0:\n                    gt_sem_seg.append(\n                        mask_tensor.new_zeros(\n                            (mask_tensor.size(0), assign_H, assign_W)))\n                else:\n                    gt_sem_seg.append(\n                        F.interpolate(\n                            sem_seg[None], (assign_H, assign_W),\n                            mode='bilinear',\n                            align_corners=False)[0])\n                gt_sem_cls.append(sem_labels)\n            else:\n                gt_sem_seg = None\n                gt_sem_cls = None\n\n            if mask_tensor.shape[0] == 0:\n                gt_masks_tensor.append(\n                    mask_tensor.new_zeros(\n                        (mask_tensor.size(0), assign_H, assign_W)))\n            else:\n                gt_masks_tensor.append(\n                    F.interpolate(\n                        mask_tensor[None], (assign_H, assign_W),  # downsample to 1/4 resolution\n                        mode='bilinear',\n                        align_corners=False)[0])\n\n        return gt_masks_tensor, gt_sem_cls, gt_sem_seg\n\n    def forward_train(self,\n                      img,\n                      img_metas,\n                      gt_bboxes=None,\n                      gt_labels=None,\n                      
gt_bboxes_ignore=None,\n                      gt_masks=None,\n                      gt_semantic_seg=None,\n                      gt_instance_ids=None,\n                      ref_img=None,\n                      ref_img_metas=None,\n                      ref_gt_bboxes_ignore=None,\n                      ref_gt_labels=None,\n                      ref_gt_bboxes=None,\n                      ref_gt_masks=None,\n                      ref_gt_semantic_seg=None,\n                      ref_gt_instance_ids=None,\n                      proposals=None,\n                      **kwargs):\n        \"\"\"Forward function of SparseR-CNN-like network in train stage.\n\n        Args:\n            img (Tensor): of shape (N, C, H, W) encoding input images.\n                Typically these should be mean centered and std scaled.\n            img_metas (list[dict]): list of image info dict where each dict\n                has: 'img_shape', 'scale_factor', 'flip', and may also contain\n                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.\n                For details on the values of these keys see\n                :class:`mmdet.datasets.pipelines.Collect`.\n            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with\n                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.\n            gt_labels (list[Tensor]): class indices corresponding to each box\n            gt_bboxes_ignore (None | list[Tensor): specify which bounding\n                boxes can be ignored when computing the loss.\n            gt_masks (List[Tensor], optional) : Segmentation masks for\n                each box. But we don't support it in this architecture.\n            proposals (List[Tensor], optional): override rpn proposals with\n                custom proposals. 
Use when `with_rpn` is False.\n\n            # This is for video only:\n            ref_img (Tensor): of shape (N, 1, C, H, W) encoding the reference\n                images. Typically these should be mean centered and std scaled.\n                Each key image has exactly one reference image; the singleton\n                dimension is squeezed out before feature extraction.\n\n            ref_img_metas (list[list[dict]]): The outer list has one entry per\n                key image; each inner list holds the meta dict of its\n                reference image. Each dict has: 'img_shape', 'scale_factor',\n                'flip', and may also contain 'filename', 'ori_shape',\n                'pad_shape', and 'img_norm_cfg'. For details on the values of\n                these keys see\n                `mmtrack/datasets/pipelines/formatting.py:VideoCollect`.\n\n            ref_gt_bboxes (list[Tensor]): The list only has one Tensor. The\n                Tensor contains ground truth bboxes for each reference image\n                with shape (num_all_ref_gts, 5) in\n                [ref_img_id, tl_x, tl_y, br_x, br_y] format. The ref_img_id\n                starts from 0 and denotes the id of the reference image for\n                each key image.\n\n            ref_gt_labels (list[Tensor]): One Tensor per key image. 
The\n                Tensor contains class indices corresponding to each reference\n                box with shape (num_all_ref_gts, 2) in\n                [ref_img_id, class_index].\n\n        Returns:\n            dict[str, Tensor]: a dictionary of loss components\n        \"\"\"\n        batch_input_shape = tuple(img[0].size()[-2:])\n        for img_meta in img_metas:\n            img_meta['batch_input_shape'] = batch_input_shape\n\n        assert proposals is None, 'KNet does not support' \\\n                                  ' external proposals'\n        assert gt_masks is not None\n        assert gt_instance_ids is not None\n\n        # preprocess the reference images\n        ref_img = ref_img.squeeze(1)  # (b,3,h,w)\n        ref_masks_gt = []\n        for ref_gt_mask in ref_gt_masks:\n            ref_masks_gt.append(ref_gt_mask[0])\n\n        ref_labels_gt = []\n        for ref_gt_label in ref_gt_labels:\n            ref_labels_gt.append(ref_gt_label[:, 1].long())\n        ref_gt_labels = ref_labels_gt\n\n        ref_semantic_seg_gt = ref_gt_semantic_seg.squeeze(1)\n\n        ref_gt_instance_id_list = []\n        for ref_gt_instance_id in ref_gt_instance_ids:\n            ref_gt_instance_id_list.append(ref_gt_instance_id[:, 1].long())\n\n        ref_img_metas_new = []\n        for ref_img_meta in ref_img_metas:\n            ref_img_meta[0]['batch_input_shape'] = batch_input_shape\n            ref_img_metas_new.append(ref_img_meta[0])\n\n        # for each key-frame instance, store the 1-based index of the\n        # reference-frame instance with the same id (0 if it has no match)\n        gt_pids_list = []\n        for i in range(len(ref_gt_instance_id_list)):\n            ref_ids = ref_gt_instance_id_list[i].cpu().numpy().tolist()\n            gt_ids = gt_instance_ids[i].cpu().numpy().tolist()\n            gt_pids = [ref_ids.index(iid) + 1 if iid in ref_ids else 0 for iid in gt_ids]\n            gt_pids_list.append(torch.LongTensor(gt_pids).to(img.device))\n        gt_pids = gt_pids_list\n\n        # gt_masks and gt_semantic_seg are not padded when forming a batch\n        gt_masks, 
gt_sem_cls, gt_sem_seg = self.preprocess_gt_masks(img_metas, gt_masks, gt_labels, gt_semantic_seg)\n\n        ref_gt_masks, ref_gt_sem_cls, ref_gt_sem_seg = self.preprocess_gt_masks(ref_img_metas_new,\n                                                                    ref_masks_gt, ref_labels_gt, ref_semantic_seg_gt)\n\n        x = self.extract_feat(img)\n        x_ref = self.extract_feat(ref_img)\n        rpn_results = self.rpn_head.forward_train(x, img_metas, gt_masks,\n                                                  gt_labels, gt_sem_seg,\n                                                  gt_sem_cls)\n\n        ref_rpn_results = self.rpn_head.forward_train(x_ref, ref_img_metas_new, ref_gt_masks,\n                                                  ref_labels_gt, ref_gt_sem_seg,\n                                                  ref_gt_sem_cls)\n\n        (rpn_losses, proposal_feats, x_feats, mask_preds,\n         cls_scores) = rpn_results\n\n        (ref_rpn_losses, ref_proposal_feats, ref_x_feats, ref_mask_preds,\n         ref_cls_scores) = ref_rpn_results\n\n        losses, sample_results, object_feats, mask_preds = self.roi_head.forward_train(\n            x_feats,\n            proposal_feats,\n            mask_preds,\n            cls_scores,\n            img_metas,\n            gt_masks,\n            gt_labels,\n            gt_pids=gt_pids,\n            gt_bboxes_ignore=gt_bboxes_ignore,\n            gt_bboxes=gt_bboxes,\n            gt_sem_seg=gt_sem_seg,\n            gt_sem_cls=gt_sem_cls,\n            imgs_whwh=None)\n\n        ref_losses, ref_sample_results, ref_object_feats, ref_mask_preds = self.roi_head.forward_train(\n            ref_x_feats,\n            ref_proposal_feats,\n            ref_mask_preds,\n            ref_cls_scores,\n            ref_img_metas,\n            ref_gt_masks,\n            ref_gt_labels,\n            gt_bboxes=ref_gt_bboxes,\n            gt_bboxes_ignore=ref_gt_bboxes_ignore,\n            gt_sem_seg=ref_gt_sem_seg,\n        
    gt_sem_cls=ref_gt_sem_cls,\n            imgs_whwh=None)\n        proposals_nums = [self.roi_head.num_proposals] * img.size()[0]\n        ref_proposals_nums = proposals_nums\n\n        thing_mask_preds, ref_thing_mask_preds = self.pack_things_masks(mask_preds, ref_mask_preds)\n        match_score = self.track_head(x, x_ref, thing_mask_preds, ref_thing_mask_preds, proposals_nums, ref_proposals_nums)\n\n        track_loss = self.track_head.loss(match_score, sample_results)\n\n        # format the loss\n        ref_rpn_losses = self.add_ref_rpn_loss(ref_rpn_losses)\n        ref_losses = self.add_ref_rpn_loss(ref_losses)\n\n        losses.update(ref_rpn_losses)\n        losses.update(ref_losses)\n        losses.update(track_loss)\n        losses.update(rpn_losses)\n\n        return losses\n\n    def simple_test(self, img, img_metas, rescale=False, ref_img=None):\n        \"\"\"Test function without test time augmentation.\n\n        Args:\n            imgs (list[torch.Tensor]): List of multiple images\n            img_metas (list[dict]): List of image information.\n            rescale (bool): Whether to rescale the results.\n                Defaults to False.\n\n        Returns:\n            list[list[np.ndarray]]: BBox results of each image and classes.\n                The outer list corresponds to each image. 
The inner list\n                corresponds to each class.\n        \"\"\"\n\n        if ref_img is not None:\n            ref_img = ref_img[0]\n        # whether is the first frame for such clips\n        assert 'city' in img_metas[0]['filename'] and 'iid' in img_metas[0]\n        iid = img_metas[0]['iid']\n        fid = iid % 10000\n        is_first = (fid == 1)\n\n        # for current frame\n        x = self.extract_feat(img)\n        rpn_results = self.rpn_head.simple_test_rpn(x, img_metas)\n        (proposal_feats, x_feats, mask_preds, cls_scores,\n         seg_preds) = rpn_results\n\n        if not is_first:\n            ref_x = self.extract_feat(ref_img)\n            ref_rpn_results = self.rpn_head.simple_test_rpn(ref_x, img_metas)\n            (ref_proposal_feats, ref_x_feats, ref_mask_preds, ref_cls_scores,\n             ref_seg_preds) = ref_rpn_results\n            x_fuse = self.combine(ref_x_feats + x_feats)\n\n        cur_segm_results, cur_object_query = self.roi_head.simple_test(\n            x_feats,\n            proposal_feats,\n            mask_preds,\n            cls_scores,\n            img_metas,\n            imgs_whwh=None,\n            rescale=rescale)\n\n        bbox_result, segm_result, panoptic_result = cur_segm_results[0]\n\n        panoptic_seg, segments_info = panoptic_result\n\n        cur_results, sseg_results = self.pack_stuff_things_result(panoptic_seg, segments_info)\n\n        if is_first:\n            self.track_query = cur_object_query\n\n        if not is_first:\n            track_seg_results = self.track_roi_head.simple_test(\n                    x_fuse,\n                    self.track_query,\n                    ref_mask_preds,\n                    ref_cls_scores,\n                    img_metas,\n                    imgs_whwh=None,\n                    rescale=rescale\n            )\n            bbox_result, segm_result, panoptic_result = track_seg_results[0]\n            track_panoptic_seg, track_segments_info = 
panoptic_result\n            track_results, ref_sseg_results = self.pack_stuff_things_result(track_panoptic_seg, track_segments_info)\n\n            # update the tracking query\n            self.track_query = cur_object_query\n\n        if is_first:\n            self.tracker.reset_all()\n            init_track_results = self.tracker.init_track(cur_results)\n            track_maps = self.generate_track_id_maps(init_track_results, panoptic_seg)\n\n        elif not is_first:\n            results = self.tracker.step(cur_results, track_results)\n            track_maps = self.generate_track_id_maps(results, panoptic_seg)\n\n        return cur_segm_results, track_maps, sseg_results\n\n    def forward_dummy(self, img):\n        \"\"\"Used for computing network flops.\n\n        See `mmdetection/tools/get_flops.py`\n        \"\"\"\n        # backbone\n        x = self.extract_feat(img)\n        # rpn\n        num_imgs = len(img)\n        dummy_img_metas = [\n            dict(img_shape=(800, 1333, 3)) for _ in range(num_imgs)\n        ]\n        rpn_results = self.rpn_head.simple_test_rpn(x, dummy_img_metas)\n        (proposal_feats, x_feats, mask_preds, cls_scores,\n         seg_preds) = rpn_results\n        roi_outs = self.roi_head.forward_dummy(x_feats, proposal_feats,\n                                               dummy_img_metas)\n        return roi_outs\n\n    def extract_feat(self, img):\n        \"\"\"Directly extract features from the backbone+neck.\"\"\"\n        x = self.backbone(img)\n        if self.with_neck:\n            x = self.neck(x)\n        return x\n\n    @property\n    def with_rpn(self):\n        \"\"\"bool: whether the detector has RPN\"\"\"\n        return hasattr(self, 'rpn_head') and self.rpn_head is not None\n\n    @property\n    def with_roi_head(self):\n        \"\"\"bool: whether the detector has a RoI head\"\"\"\n        return hasattr(self, 'roi_head') and self.roi_head is not None\n\n    def aug_test(self, x, proposal_list, img_metas, 
rescale=False, **kwargs):\n        \"\"\"Test with augmentations.\n\n        If rescale is False, then returned bboxes and masks will fit the scale\n        of imgs[0].\n        \"\"\"\n        pass\n\n    def pack_things_object(self, object_feats, ref_object_feats):\n        object_feats, ref_object_feats = object_feats.squeeze(-1).squeeze(-1), ref_object_feats.squeeze(-1).squeeze(-1)\n        thing_object_feats = torch.split(object_feats, [self.roi_head.num_proposals, self.num_stuff_classes], dim=1)[0]\n        ref_thing_object_feats = torch.split(ref_object_feats, [self.roi_head.num_proposals, self.num_stuff_classes], dim=1)[0]\n        return thing_object_feats, ref_thing_object_feats\n\n    def pack_things_masks(self, mask_pred, ref_mask_pred):\n        thing_mask_pred = torch.split(mask_pred, [self.roi_head.num_proposals, self.num_stuff_classes], dim=1)[0]\n        ref_thing_mask_pred = torch.split(ref_mask_pred, [self.roi_head.num_proposals, self.num_stuff_classes], dim=1)[0]\n        return thing_mask_pred, ref_thing_mask_pred\n\n    def add_track_loss(self, loss_dict):\n        track_loss = {}\n        for k, v in loss_dict.items():\n            track_loss[str(k) + \"_track\"] = v\n        return track_loss\n\n    def add_ref_rpn_loss(self, loss_dict):\n        ref_rpn_loss = {}\n        for k, v in loss_dict.items():\n            ref_rpn_loss[str(k) + \"_ref\"] = v\n        return ref_rpn_loss\n\n    def pack_stuff_things_result(self, panoptic_seg, segments_info):\n        results = {}\n        masks = []\n        scores = []\n        semantic_seg = np.zeros(panoptic_seg.shape)\n        for segment in segments_info:\n            if segment['isthing']:\n                thing_mask = panoptic_seg == segment[\"id\"]\n                masks.append(thing_mask)\n                scores.append(segment[\"score\"])\n                # things: offset category ids by 11 so that they\n                # follow the stuff labels\n                semantic_seg[panoptic_seg == segment[\"id\"]] = 
segment[\"category_id\"] + 11\n            else:\n                # stuff: shift category ids to start from 0 (0 .. n-1)\n                semantic_seg[panoptic_seg == segment[\"id\"]] = segment[\"category_id\"] - 1\n\n        results[\"masks\"] = np.array(masks)  # (N, H, W)\n        results[\"scores\"] = np.array(scores)  # (N,)\n\n        return results, semantic_seg\n\n    def generate_track_id_maps(self, track_results, panoptic_seg_maps):\n        final_id_maps = np.zeros(panoptic_seg_maps.shape)\n        # assert len(things_mask_results) == len(track_results)\n        for track in track_results:\n            track_id = track[\"tracking_id\"]\n            mask = track[\"mask\"]\n            final_id_maps[mask] = track_id\n        return final_id_maps"
  },
  {
    "path": "knet/video/knet_uni_track.py",
    "content": "import warnings\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom mmdet.models.builder import DETECTORS\nfrom mmdet.models.detectors import BaseDetector\nfrom mmdet.models.builder import build_head, build_neck, build_backbone\nfrom knet.det.utils import sem2ins_masks, sem2ins_masks_cityscapes\nfrom unitrack.mask import MaskAssociationTracker\n\n\n@DETECTORS.register_module()\nclass VideoKNetUniTrack(BaseDetector):\n    def __init__(self,\n                 backbone,\n                 neck=None,\n                 rpn_head=None,\n                 roi_head=None,\n                 track_roi_head=None,\n                 train_cfg=None,\n                 test_cfg=None,\n                 pretrained=None,\n                 init_cfg=None,\n                 num_thing_classes=80,\n                 num_stuff_classes=53,\n                 mask_assign_stride=4,\n                 ignore_label=255,\n                 thing_label_in_seg=0,\n                 kitti_step=False,\n                 cityscapes=False,\n                 uni_tracker_cfg=None,\n                 **kwargs):\n        super(VideoKNetUniTrack, self).__init__(init_cfg)\n\n        if pretrained:\n            warnings.warn('DeprecationWarning: pretrained is deprecated, '\n                          'please use \"init_cfg\" instead')\n            backbone.pretrained = pretrained\n        self.backbone = build_backbone(backbone)\n\n        if neck is not None:\n            self.neck = build_neck(neck)\n\n        if rpn_head is not None:\n            rpn_train_cfg = train_cfg.rpn if train_cfg is not None else None\n            rpn_head_ = rpn_head.copy()\n            rpn_head_.update(train_cfg=rpn_train_cfg, test_cfg=test_cfg.rpn)\n            self.rpn_head = build_head(rpn_head_)\n\n        if roi_head is not None:\n            # update train and test cfg here for now\n            # TODO: refactor assigner & sampler\n            rcnn_train_cfg = train_cfg.rcnn if train_cfg is not 
None else None\n            roi_head.update(train_cfg=rcnn_train_cfg)\n            roi_head.update(test_cfg=test_cfg.rcnn)\n            roi_head.pretrained = pretrained\n            self.roi_head = build_head(roi_head)\n\n        self.tracker = MaskAssociationTracker(uni_tracker_cfg)\n        self.img0 = None\n        self.train_cfg = train_cfg\n        self.test_cfg = test_cfg\n\n        self.num_thing_classes = num_thing_classes\n        self.num_stuff_classes = num_stuff_classes\n        self.mask_assign_stride = mask_assign_stride\n        self.thing_label_in_seg = thing_label_in_seg\n        self.ignore_label = ignore_label\n        self.cityscapes = cityscapes  # whether to train on Cityscapes panoptic segmentation\n        self.kitti_step = kitti_step  # whether to use the KITTI-STEP dataset\n\n    def preprocess_gt_masks(self, img_metas, gt_masks, gt_labels, gt_semantic_seg):\n        # gt_masks and gt_semantic_seg are not padded when forming a batch\n        gt_masks_tensor = []\n        gt_sem_seg = []\n        gt_sem_cls = []\n        # batch_input_shape should be the same across images\n        pad_H, pad_W = img_metas[0]['batch_input_shape']\n        assign_H = pad_H // self.mask_assign_stride\n        assign_W = pad_W // self.mask_assign_stride\n\n        for i, gt_mask in enumerate(gt_masks):\n            mask_tensor = gt_mask.to_tensor(torch.float, gt_labels[0].device)\n            if gt_mask.width != pad_W or gt_mask.height != pad_H:\n                pad_wh = (0, pad_W - gt_mask.width, 0, pad_H - gt_mask.height)\n                mask_tensor = F.pad(mask_tensor, pad_wh, value=0)\n\n            if gt_semantic_seg is not None:\n                # gt_semantic_seg is zero-padded when forming a batch, so\n                # set the padded region to the ignore label\n                gt_semantic_seg[\n                i, :, img_metas[i]['img_shape'][0]:, :] = self.ignore_label\n                gt_semantic_seg[\n                i, :, :, img_metas[i]['img_shape'][1]:] = 
self.ignore_label\n                if self.cityscapes:\n                    sem_labels, sem_seg = sem2ins_masks_cityscapes(\n                        gt_semantic_seg[i],\n                        ignore_label=self.ignore_label,\n                        label_shift=self.num_thing_classes)\n                else:\n                    sem_labels, sem_seg = sem2ins_masks(\n                        gt_semantic_seg[i],\n                        ignore_label=self.ignore_label,\n                        label_shift=self.num_thing_classes,\n                        thing_label_in_seg=self.thing_label_in_seg)\n\n                if sem_seg.shape[0] == 0:\n                    gt_sem_seg.append(\n                        mask_tensor.new_zeros(\n                            (mask_tensor.size(0), assign_H, assign_W)))\n                else:\n                    gt_sem_seg.append(\n                        F.interpolate(\n                            sem_seg[None], (assign_H, assign_W),\n                            mode='bilinear',\n                            align_corners=False)[0])\n                gt_sem_cls.append(sem_labels)\n            else:\n                gt_sem_seg = None\n                gt_sem_cls = None\n\n            if mask_tensor.shape[0] == 0:\n                gt_masks_tensor.append(\n                    mask_tensor.new_zeros(\n                        (mask_tensor.size(0), assign_H, assign_W)))\n            else:\n                gt_masks_tensor.append(\n                    F.interpolate(\n                        mask_tensor[None], (assign_H, assign_W),  # downsample to 1/4 resolution\n                        mode='bilinear',\n                        align_corners=False)[0])\n\n        return gt_masks_tensor, gt_sem_cls, gt_sem_seg\n\n    def forward_train(self,\n                      img,\n                      img_metas,\n                      gt_bboxes=None,\n                      gt_labels=None,\n                      gt_bboxes_ignore=None,\n                      
gt_masks=None,\n                      gt_semantic_seg=None,\n                      ref_img=None,\n                      ref_img_metas=None,\n                      ref_gt_bboxes_ignore=None,\n                      ref_gt_labels=None,\n                      ref_gt_masks=None,\n                      ref_gt_semantic_seg=None,\n                      proposals=None,\n                      **kwargs):\n        \"\"\"Forward function of SparseR-CNN-like network in train stage.\n\n        Args:\n            img (Tensor): of shape (N, C, H, W) encoding input images.\n                Typically these should be mean centered and std scaled.\n            img_metas (list[dict]): list of image info dict where each dict\n                has: 'img_shape', 'scale_factor', 'flip', and may also contain\n                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.\n                For details on the values of these keys see\n                :class:`mmdet.datasets.pipelines.Collect`.\n            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with\n                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.\n            gt_labels (list[Tensor]): class indices corresponding to each box\n            gt_bboxes_ignore (None | list[Tensor): specify which bounding\n                boxes can be ignored when computing the loss.\n            gt_masks (List[Tensor], optional) : Segmentation masks for\n                each box. But we don't support it in this architecture.\n            proposals (List[Tensor], optional): override rpn proposals with\n                custom proposals. 
Use when `with_rpn` is False.\n\n            # This is for video only:\n            ref_img (Tensor): of shape (N, 1, C, H, W) encoding the reference\n                images. Typically these should be mean centered and std scaled.\n                Each key image has exactly one reference image; the singleton\n                dimension is squeezed out before feature extraction.\n\n            ref_img_metas (list[list[dict]]): The outer list has one entry per\n                key image; each inner list holds the meta dict of its\n                reference image. Each dict has: 'img_shape', 'scale_factor',\n                'flip', and may also contain 'filename', 'ori_shape',\n                'pad_shape', and 'img_norm_cfg'. For details on the values of\n                these keys see\n                `mmtrack/datasets/pipelines/formatting.py:VideoCollect`.\n\n            ref_gt_bboxes (list[Tensor]): The list only has one Tensor. The\n                Tensor contains ground truth bboxes for each reference image\n                with shape (num_all_ref_gts, 5) in\n                [ref_img_id, tl_x, tl_y, br_x, br_y] format. The ref_img_id\n                starts from 0 and denotes the id of the reference image for\n                each key image.\n\n            ref_gt_labels (list[Tensor]): One Tensor per key image. 
The\n                Tensor contains class indices corresponding to each reference\n                box with shape (num_all_ref_gts, 2) in\n                [ref_img_id, class_index].\n\n        Returns:\n            dict[str, Tensor]: a dictionary of loss components\n        \"\"\"\n        batch_input_shape = tuple(img[0].size()[-2:])\n        for img_meta in img_metas:\n            img_meta['batch_input_shape'] = batch_input_shape\n\n        assert proposals is None, 'KNet does not support' \\\n                                  ' external proposals'\n        assert gt_masks is not None\n\n        # preprocess the reference images\n        ref_img = ref_img.squeeze(1)  # (b,3,h,w)\n        ref_masks_gt = []\n        for ref_gt_mask in ref_gt_masks:\n            ref_masks_gt.append(ref_gt_mask[0])\n        ref_labels_gt = []\n        for ref_gt_label in ref_gt_labels:\n            ref_labels_gt.append(ref_gt_label[:, 1].long())\n        ref_gt_labels = ref_labels_gt\n        ref_semantic_seg_gt = ref_gt_semantic_seg.squeeze(1)\n\n        ref_img_metas_new = []\n        for ref_img_meta in ref_img_metas:\n            ref_img_meta[0]['batch_input_shape'] = batch_input_shape\n            ref_img_metas_new.append(ref_img_meta[0])\n\n        # gt_masks and gt_semantic_seg are not padded when forming a batch\n        gt_masks, gt_sem_cls, gt_sem_seg = self.preprocess_gt_masks(img_metas, gt_masks, gt_labels, gt_semantic_seg)\n\n        ref_gt_masks, ref_gt_sem_cls, ref_gt_sem_seg = self.preprocess_gt_masks(ref_img_metas_new,\n                                                                    ref_masks_gt, ref_labels_gt, ref_semantic_seg_gt)\n        x = self.extract_feat(img)\n        x_ref = self.extract_feat(ref_img)\n        rpn_results = self.rpn_head.forward_train(x, img_metas, gt_masks,\n                                                  gt_labels, gt_sem_seg,\n                                                  gt_sem_cls)\n\n        ref_rpn_results = 
self.rpn_head.forward_train(x_ref, ref_img_metas_new, ref_gt_masks,\n                                                  ref_labels_gt, ref_gt_sem_seg,\n                                                  ref_gt_sem_cls)\n\n        (rpn_losses, proposal_feats, x_feats, mask_preds,\n         cls_scores) = rpn_results\n\n        (ref_rpn_losses, ref_proposal_feats, ref_x_feats, ref_mask_preds,\n         ref_cls_scores) = ref_rpn_results\n\n        x_fuse = self.combine(ref_x_feats + x_feats)\n\n        losses, cur_object_query = self.roi_head.forward_train(\n            x_feats,\n            proposal_feats,\n            mask_preds,\n            cls_scores,\n            img_metas,\n            gt_masks,\n            gt_labels,\n            gt_bboxes_ignore=gt_bboxes_ignore,\n            gt_bboxes=gt_bboxes,\n            gt_sem_seg=gt_sem_seg,\n            gt_sem_cls=gt_sem_cls,\n            imgs_whwh=None)\n\n        track_query_loss = self.track_roi_head.forward_train(\n            x_fuse,\n            cur_object_query,\n            ref_mask_preds,\n            ref_cls_scores,\n            ref_img_metas_new,\n            ref_gt_masks,\n            ref_gt_labels,\n            gt_sem_seg=ref_gt_sem_seg,\n            gt_sem_cls=ref_gt_sem_cls,\n            imgs_whwh=None\n        )\n\n        track_query_loss = self.add_track_loss(track_query_loss)\n        ref_rpn_losses = self.add_ref_rpn_loss(ref_rpn_losses)\n        # single frame loss\n        # query track loss for reference frame\n        losses.update(ref_rpn_losses)\n        losses.update(rpn_losses)\n        losses.update(track_query_loss)\n\n        return losses\n\n    def simple_test(self, img, img_metas, rescale=False, ref_img=None):\n        \"\"\"Test function without test time augmentation.\n\n        Args:\n            imgs (list[torch.Tensor]): List of multiple images\n            img_metas (list[dict]): List of image information.\n            rescale (bool): Whether to rescale the results.\n              
  Defaults to False.\n\n        Returns:\n            tuple: ``(semantic_map, track_maps, None, vis_sem, vis_tracker)``,\n                where ``semantic_map`` holds the per-pixel semantic class id\n                (Cityscapes train-id layout), ``track_maps`` holds the track id\n                of each tracked thing pixel (0 elsewhere), and ``vis_sem`` /\n                ``vis_tracker`` are RGB visualizations of these two maps.\n        \"\"\"\n        if ref_img is not None:\n            ref_img = ref_img[0]\n        # a frame starts a new clip when its frame id (iid % 10000) is 1\n        assert 'city' in img_metas[0]['filename'] and 'iid' in img_metas[0]\n        iid = img_metas[0]['iid']\n        fid = iid % 10000\n        is_first = (fid == 1)\n\n        # for current frame\n        x = self.extract_feat(img)\n        rpn_results = self.rpn_head.simple_test_rpn(x, img_metas)\n        (proposal_feats, x_feats, mask_preds, cls_scores,\n         seg_preds) = rpn_results\n\n        # Changed from the notation above; needs further checking.\n        cur_segm_results, object_feats, cls_score, mask_preds, scaled_mask_preds = self.roi_head.simple_test(\n            x_feats,\n            proposal_feats,\n            mask_preds,\n            cls_scores,\n            img_metas)\n\n        bbox_result, segm_result, thing_mask_preds, panoptic_result = cur_segm_results[0]\n\n        panoptic_seg, segments_info = panoptic_result\n\n        cur_results, sseg_results = self.pack_stuff_things_result(panoptic_seg, segments_info)\n\n        if is_first:\n            # first frame of a clip: keep it as self.img0 and reset all tracks\n            self.img0 = img\n            self.tracker.reset_all()\n            if len(cur_results[\"masks\"]) == 0:\n                track_maps = np.zeros(panoptic_seg.shape)\n            else:\n                init_track_results = self.tracker.update(img, self.img0, cur_results[\"masks\"])\n                track_maps = self.generate_track_id_maps(init_track_results, panoptic_seg)\n\n        else:\n            if len(cur_results[\"masks\"]) == 0:\n                track_maps = np.zeros(panoptic_seg.shape)\n            else:\n                results = self.tracker.update(img, self.img0, cur_results[\"masks\"])\n                track_maps = self.generate_track_id_maps(results, panoptic_seg)\n\n        semantic_map = self.get_semantic_seg(panoptic_seg, segments_info)\n\n        from scripts.visualizer import trackmap2rgb, cityscapes_cat2rgb\n        vis_tracker = trackmap2rgb(track_maps)\n        vis_sem = cityscapes_cat2rgb(semantic_map)\n\n        return semantic_map, track_maps, None, vis_sem, vis_tracker\n\n    def forward_dummy(self, img):\n        \"\"\"Used for computing network flops.\n\n        See `mmdetection/tools/get_flops.py`\n        \"\"\"\n        # backbone\n        x = self.extract_feat(img)\n        # rpn\n        num_imgs = len(img)\n        dummy_img_metas = [\n            dict(img_shape=(800, 1333, 3)) for _ in range(num_imgs)\n        ]\n        rpn_results = self.rpn_head.simple_test_rpn(x, dummy_img_metas)\n        (proposal_feats, x_feats, mask_preds, cls_scores,\n         seg_preds) = rpn_results\n        # roi_head\n        roi_outs = self.roi_head.forward_dummy(x_feats, proposal_feats,\n                                               dummy_img_metas)\n        return roi_outs\n\n    def extract_feat(self, img):\n        \"\"\"Directly extract features from the backbone+neck.\"\"\"\n        x = self.backbone(img)\n        if self.with_neck:\n            x = self.neck(x)\n        return x\n\n    @property\n    def with_rpn(self):\n        \"\"\"bool: whether the detector has RPN\"\"\"\n        return hasattr(self, 'rpn_head') and self.rpn_head is not None\n\n    @property\n    def with_roi_head(self):\n        \"\"\"bool: whether the detector has a RoI head\"\"\"\n        return hasattr(self, 'roi_head') and self.roi_head is not None\n\n    def aug_test(self, x, proposal_list, img_metas, rescale=False, **kwargs):\n        \"\"\"Test with augmentations.\n\n        Not implemented; this method returns None.\n        \"\"\"\n        pass\n\n    def add_track_loss(self, loss_dict):\n        track_loss = {}\n        for k, v in loss_dict.items():\n            track_loss[str(k) + \"_track\"] = v\n        return track_loss\n\n    def add_ref_rpn_loss(self, loss_dict):\n        ref_rpn_loss = {}\n        for k, v in loss_dict.items():\n            ref_rpn_loss[str(k) + \"_ref\"] = v\n        return ref_rpn_loss\n\n    def pack_stuff_things_result(self, panoptic_seg, segments_info):\n        results = {}\n        masks = []\n        scores = []\n        semantic_seg = np.zeros(panoptic_seg.shape)\n        for segment in segments_info:\n            if segment['isthing']:\n                thing_mask = panoptic_seg == segment[\"id\"]\n                masks.append(thing_mask)\n                scores.append(segment[\"score\"])\n                # things: shift 0-based thing ids past the 11 stuff classes\n                semantic_seg[thing_mask] = segment[\"category_id\"] + 11\n            else:\n                # stuff: shift 1-based stuff ids to 0 .. n-1\n                semantic_seg[panoptic_seg == segment[\"id\"]] = segment[\"category_id\"] - 1\n\n        results[\"masks\"] = np.array(masks)  # (N, H, W)\n        results[\"scores\"] = np.array(scores)  # (N,)\n\n        return results, semantic_seg\n\n    def generate_track_id_maps(self, track_results, panoptic_seg_maps):\n        final_id_maps = np.zeros(panoptic_seg_maps.shape)\n        for track in track_results:\n            # paint each tracked mask with its track id\n            final_id_maps[track.mask] = track.track_id\n        return final_id_maps\n\n    def get_semantic_seg(self, panoptic_seg, segments_info):\n        # KITTI-STEP thing classes map to Cityscapes train ids 11 (person) and 13 (car)\n        kitti_step2cityscapes = [11, 13]\n        semantic_seg = np.zeros(panoptic_seg.shape)\n        for segment in segments_info:\n            if segment['isthing']:\n                if self.kitti_step:\n                    cat_cur = kitti_step2cityscapes[segment[\"category_id\"]]\n                    semantic_seg[panoptic_seg == segment[\"id\"]] = cat_cur\n                else:\n                    semantic_seg[panoptic_seg == segment[\"id\"]] = segment[\"category_id\"] + 11\n            else:\n                # stuff: shift 1-based stuff ids to 0 .. n-1\n                if self.kitti_step:\n                    cat_cur = segment[\"category_id\"]\n                    cat_cur -= 1\n                    # skip over the Cityscapes ids reserved for thing classes\n                    offset = 0\n                    for thing_id in kitti_step2cityscapes:\n                        if cat_cur + offset >= thing_id:\n                            offset += 1\n                    cat_cur += offset\n                    semantic_seg[panoptic_seg == segment[\"id\"]] = cat_cur\n                else:\n                    semantic_seg[panoptic_seg == segment[\"id\"]] = segment[\"category_id\"] - 1\n        return semantic_seg"
  },
  {
    "path": "knet/video/mask_hungarian_assigner.py",
    "content": "import numpy as np\nimport torch\nfrom mmdet.core import AssignResult, BaseAssigner, reduce_mean\nfrom mmdet.core.bbox.builder import BBOX_ASSIGNERS\nfrom mmdet.core.bbox.match_costs.builder import MATCH_COST, build_match_cost\n\ntry:\n    from scipy.optimize import linear_sum_assignment\nexcept ImportError:\n    linear_sum_assignment = None\n\n\n@MATCH_COST.register_module()\nclass DiceCost(object):\n    \"\"\"DiceCost.\n\n    Args:\n        weight (int | float, optional): loss_weight\n        pred_act (bool): Whether to activate the prediction\n            before calculating cost\n        act_mode (str): 'sigmoid' applies a sigmoid; any other value applies\n            a softmax over dim 0 (the query dimension). Only used when\n            `pred_act` is True.\n        eps (float): Smoothing term added to the denominator.\n\n    Examples:\n        >>> import torch\n        >>> self = DiceCost()\n        >>> mask_preds = torch.rand(3, 16, 16)\n        >>> gt_masks = torch.randint(0, 2, (2, 16, 16))\n        >>> self(mask_preds, gt_masks).shape\n        torch.Size([3, 2])\n    \"\"\"\n\n    def __init__(self,\n                 weight=1.,\n                 pred_act=False,\n                 act_mode='sigmoid',\n                 eps=1e-3):\n        self.weight = weight\n        self.pred_act = pred_act\n        self.act_mode = act_mode\n        self.eps = eps\n\n    def dice_loss(self, input, target, eps=1e-3):\n        input = input.reshape(input.size()[0], -1)\n        target = target.reshape(target.size()[0], -1).float()\n        # einsum saves 10x memory\n        # a = torch.sum(input[:, None] * target[None, ...], -1)\n        a = torch.einsum('nh,mh->nm', input, target)\n        b = torch.sum(input * input, 1) + eps\n        c = torch.sum(target * target, 1) + eps\n        d = (2 * a) / (b[:, None] + c[None, ...])\n        # the dice loss is 1 - d; the constant 1 does not affect the\n        # matching, so it is omitted\n        return -d\n\n    def __call__(self, mask_preds, gt_masks):\n        \"\"\"\n        Args:\n            mask_preds (Tensor): Predicted masks, shape [num_query, H, W].\n            gt_masks (Tensor): Ground truth binary masks, shape\n                [num_gt, H, W].\n\n        Returns:\n            torch.Tensor: Negative dice coefficient multiplied by `weight`,\n                shape [num_query, num_gt].\n        \"\"\"\n        if self.pred_act and self.act_mode == 'sigmoid':\n            mask_preds = mask_preds.sigmoid()\n        elif self.pred_act:\n            mask_preds = mask_preds.softmax(dim=0)\n        dice_cost = self.dice_loss(mask_preds, gt_masks, self.eps)\n        return dice_cost * self.weight\n\n\n@MATCH_COST.register_module()\nclass MaskCost(object):\n    \"\"\"MaskCost.\n\n    Per-pixel agreement cost between predicted and ground truth masks.\n\n    Args:\n        weight (int | float, optional): loss_weight\n        pred_act (bool): Whether to activate the prediction\n            before calculating cost\n        act_mode (str): 'sigmoid' applies a sigmoid; any other value applies\n            a softmax over dim 0. Only used when `pred_act` is True.\n    \"\"\"\n\n    def __init__(self, weight=1., pred_act=False, act_mode='sigmoid'):\n        self.weight = weight\n        self.pred_act = pred_act\n        self.act_mode = act_mode\n\n    def __call__(self, cls_pred, target):\n        \"\"\"\n        Args:\n            cls_pred (Tensor): Predicted masks, shape [num_query, H, W].\n            target (Tensor): Ground truth binary masks, shape [num_gt, H, W].\n\n        Returns:\n            torch.Tensor: Negative mean per-pixel agreement multiplied by\n                `weight`, shape [num_query, num_gt].\n        \"\"\"\n        if self.pred_act and self.act_mode == 'sigmoid':\n            cls_pred = cls_pred.sigmoid()\n        elif self.pred_act:\n            cls_pred = cls_pred.softmax(dim=0)\n        num_proposals = cls_pred.shape[0]\n        num_gts, H, W = target.shape\n        # flatten_cls_pred = cls_pred.view(num_proposals, -1)\n        # einsum is ~10 times faster than matmul\n        pos_cost = torch.einsum('nhw,mhw->nm', cls_pred, target)\n        neg_cost = torch.einsum('nhw,mhw->nm', 1 - cls_pred, 1 - target)\n        # flatten_target = target.view(num_gts, -1).t()\n        # pos_cost = flatten_cls_pred.matmul(flatten_target)\n        # neg_cost = (1 - 
flatten_cls_pred).matmul(1 - flatten_target)\n        cls_cost = -(pos_cost + neg_cost) / (H * W)\n        return cls_cost * self.weight\n\n\n@BBOX_ASSIGNERS.register_module()\nclass MaskHungarianAssigner(BaseAssigner):\n    \"\"\"Computes one-to-one matching between predictions and ground truth.\n\n    This class computes an assignment between the targets and the predictions\n    based on the costs. The cost is a weighted sum of a classification cost,\n    a mask cost, a dice cost and an optional boundary cost. The\n    targets don't include the no_object, so generally there are more\n    predictions than targets. After the one-to-one matching, the unmatched\n    predictions are treated as background. Thus each query prediction will be\n    assigned `0` or a positive integer indicating the ground truth index:\n\n    - 0: negative sample, no assigned gt\n    - positive integer: positive sample, index (1-based) of assigned gt\n\n    Args:\n        cls_cost (dict): Config of the classification match cost.\n            Default: ClassificationCost with weight 1.0.\n        mask_cost (dict): Config of the mask match cost.\n            Default: SigmoidCost with weight 1.0.\n        dice_cost (dict): Config of the dice match cost.\n        boundary_cost (dict, optional): Config of the boundary match cost.\n            Default None (no boundary cost).\n        topk (int): Number of predictions matched to each ground truth.\n            Hungarian matching is run `topk` times, each time excluding the\n            predictions matched in earlier rounds. Default 1.\n    \"\"\"\n\n    def __init__(self,\n                 cls_cost=dict(type='ClassificationCost', weight=1.),\n                 mask_cost=dict(type='SigmoidCost', weight=1.0),\n                 dice_cost=dict(),\n                 boundary_cost=None,\n                 topk=1):\n        self.cls_cost = build_match_cost(cls_cost)\n        self.mask_cost = build_match_cost(mask_cost)\n        self.dice_cost = build_match_cost(dice_cost)\n        if boundary_cost is not None:\n            self.boundary_cost = build_match_cost(boundary_cost)\n        else:\n            self.boundary_cost = None\n        self.topk = topk\n\n    def assign(self,\n               bbox_pred,\n               cls_pred,\n               gt_bboxes,\n               gt_labels,\n               img_meta=None,\n               gt_bboxes_ignore=None,\n               eps=1e-7):\n        \"\"\"Computes one-to-one matching based on the weighted costs.\n\n        This method assigns each query prediction to a ground truth or\n        background. The `assigned_gt_inds` with -1 means don't care,\n        0 means negative sample, and positive number is the index (1-based)\n        of assigned gt.\n        The assignment is done in the following steps; the order matters.\n\n        1. assign every prediction to -1\n        2. compute the weighted costs\n        3. do Hungarian matching on CPU based on the costs\n        4. assign all to 0 (background) first, then for each matched pair\n           between predictions and gts, treat this prediction as foreground\n           and assign the corresponding gt index (plus 1) to it.\n\n        Args:\n            bbox_pred (Tensor): Predicted masks, shape [num_query, H, W].\n            cls_pred (Tensor): Predicted classification logits, shape\n                [num_query, num_class].\n            gt_bboxes (Tensor): Ground truth binary masks, shape\n                [num_gt, H, W].\n            gt_labels (Tensor): Labels of the ground truth masks, shape\n                (num_gt,).\n            img_meta (dict): Meta information for current image.\n            gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are\n                labelled as `ignored`. Only None is supported. Default None.\n            eps (int | float, optional): Unused; kept for interface\n                compatibility. Default 1e-7.\n\n        Returns:\n            :obj:`AssignResult`: The assigned result.\n        \"\"\"\n        assert gt_bboxes_ignore is None, \\\n            'Only case when gt_bboxes_ignore is None is supported.'\n        num_gts, num_bboxes = gt_bboxes.size(0), bbox_pred.size(0)\n\n        # 1. assign -1 by default\n        assigned_gt_inds = bbox_pred.new_full((num_bboxes, ),\n                                              -1,\n                                              dtype=torch.long)\n        assigned_labels = bbox_pred.new_full((num_bboxes, ),\n                                             -1,\n                                             dtype=torch.long)\n        if num_gts == 0 or num_bboxes == 0:\n            # No ground truth or predictions, return empty assignment\n            if num_gts == 0:\n                # No ground truth, assign all to background\n                assigned_gt_inds[:] = 0\n            return AssignResult(\n                num_gts, assigned_gt_inds, None, labels=assigned_labels)\n\n        # 2. 
compute the weighted costs\n        # classification and bboxcost.\n        if self.cls_cost.weight != 0 and cls_pred is not None:\n            cls_cost = self.cls_cost(cls_pred, gt_labels)\n        else:\n            cls_cost = 0\n        if self.mask_cost.weight != 0:\n            reg_cost = self.mask_cost(bbox_pred, gt_bboxes)\n        else:\n            reg_cost = 0\n        if self.dice_cost.weight != 0:\n            dice_cost = self.dice_cost(bbox_pred, gt_bboxes)\n        else:\n            dice_cost = 0\n        if self.boundary_cost is not None and self.boundary_cost.weight != 0:\n            b_cost = self.boundary_cost(bbox_pred, gt_bboxes)\n        else:\n            b_cost = 0\n        cost = cls_cost + reg_cost + dice_cost + b_cost\n\n        # 3. do Hungarian matching on CPU using linear_sum_assignment\n        cost = cost.detach().cpu()\n        if linear_sum_assignment is None:\n            raise ImportError('Please run \"pip install scipy\" '\n                              'to install scipy first.')\n        if self.topk == 1:\n            matched_row_inds, matched_col_inds = linear_sum_assignment(cost)\n        else:\n            topk_matched_row_inds = []\n            topk_matched_col_inds = []\n            for i in range(self.topk):\n                matched_row_inds, matched_col_inds = linear_sum_assignment(\n                    cost)\n                topk_matched_row_inds.append(matched_row_inds)\n                topk_matched_col_inds.append(matched_col_inds)\n                cost[matched_row_inds] = 1e10\n            matched_row_inds = np.concatenate(topk_matched_row_inds)\n            matched_col_inds = np.concatenate(topk_matched_col_inds)\n\n        matched_row_inds = torch.from_numpy(matched_row_inds).to(\n            bbox_pred.device)\n        matched_col_inds = torch.from_numpy(matched_col_inds).to(\n            bbox_pred.device)\n\n        # 4. 
assign backgrounds and foregrounds\n        # assign all indices to backgrounds first\n        assigned_gt_inds[:] = 0\n        # assign foregrounds based on matching results\n        assigned_gt_inds[matched_row_inds] = matched_col_inds + 1\n        assigned_labels[matched_row_inds] = gt_labels[matched_col_inds]\n        return AssignResult(\n            num_gts, assigned_gt_inds, None, labels=assigned_labels)\n\n\n@BBOX_ASSIGNERS.register_module()\nclass MaskHungarianAssignerWithEmbed(BaseAssigner):\n    \"\"\"Computes one-to-one matching between predictions and ground truth.\n\n    Same matching as `MaskHungarianAssigner`; `assign` additionally accepts\n    an `embed_pred` argument.\n\n    This class computes an assignment between the targets and the predictions\n    based on the costs. The cost is a weighted sum of a classification cost,\n    a mask cost, a dice cost and an optional boundary cost. The\n    targets don't include the no_object, so generally there are more\n    predictions than targets. After the one-to-one matching, the unmatched\n    predictions are treated as background. Thus each query prediction will be\n    assigned `0` or a positive integer indicating the ground truth index:\n\n    - 0: negative sample, no assigned gt\n    - positive integer: positive sample, index (1-based) of assigned gt\n\n    Args:\n        cls_cost (dict): Config of the classification match cost.\n            Default: ClassificationCost with weight 1.0.\n        mask_cost (dict): Config of the mask match cost.\n            Default: SigmoidCost with weight 1.0.\n        dice_cost (dict): Config of the dice match cost.\n        boundary_cost (dict, optional): Config of the boundary match cost.\n            Default None (no boundary cost).\n        topk (int): Number of predictions matched to each ground truth.\n            Hungarian matching is run `topk` times, each time excluding the\n            predictions matched in earlier rounds. Default 1.\n    \"\"\"\n\n    def __init__(self,\n                 cls_cost=dict(type='ClassificationCost', weight=1.),\n                 mask_cost=dict(type='SigmoidCost', weight=1.0),\n                 dice_cost=dict(),\n                 boundary_cost=None,\n                 topk=1):\n        self.cls_cost = build_match_cost(cls_cost)\n        self.mask_cost = build_match_cost(mask_cost)\n        self.dice_cost = build_match_cost(dice_cost)\n        if boundary_cost is not None:\n            self.boundary_cost = build_match_cost(boundary_cost)\n        else:\n            self.boundary_cost = None\n        self.topk = topk\n\n    def assign(self,\n               bbox_pred,\n               cls_pred,\n               gt_bboxes,\n               gt_labels,\n               embed_pred=None,\n               img_meta=None,\n               gt_bboxes_ignore=None,\n               eps=1e-7):\n        \"\"\"Computes one-to-one matching based on the weighted costs.\n\n        This method assigns each query prediction to a ground truth or\n        background. The `assigned_gt_inds` with -1 means don't care,\n        0 means negative sample, and positive number is the index (1-based)\n        of assigned gt.\n        The assignment is done in the following steps; the order matters.\n\n        1. assign every prediction to -1\n        2. compute the weighted costs\n        3. do Hungarian matching on CPU based on the costs\n        4. assign all to 0 (background) first, then for each matched pair\n           between predictions and gts, treat this prediction as foreground\n           and assign the corresponding gt index (plus 1) to it.\n\n        Args:\n            bbox_pred (Tensor): Predicted masks, shape [num_query, H, W].\n            cls_pred (Tensor): Predicted classification logits, shape\n                [num_query, num_class].\n            gt_bboxes (Tensor): Ground truth binary masks, shape\n                [num_gt, H, W].\n            gt_labels (Tensor): Labels of the ground truth masks, shape\n                (num_gt,).\n            embed_pred (Tensor, optional): Not used by the matching costs.\n                Default None.\n            img_meta (dict): Meta information for current image.\n            gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are\n                labelled as `ignored`. Only None is supported. Default None.\n            eps (int | float, optional): Unused; kept for interface\n                compatibility. Default 1e-7.\n\n        Returns:\n            :obj:`AssignResult`: The assigned result.\n        \"\"\"\n        assert gt_bboxes_ignore is None, \\\n            'Only case when gt_bboxes_ignore is None is supported.'\n        num_gts, num_bboxes = gt_bboxes.size(0), bbox_pred.size(0)\n\n        # 1. assign -1 by default\n        assigned_gt_inds = bbox_pred.new_full((num_bboxes, ),\n                                              -1,\n                                              dtype=torch.long)\n        assigned_labels = bbox_pred.new_full((num_bboxes, ),\n                                             -1,\n                                             dtype=torch.long)\n        if num_gts == 0 or num_bboxes == 0:\n            # No ground truth or predictions, return empty assignment\n            if num_gts == 0:\n                # No ground truth, assign all to background\n                assigned_gt_inds[:] = 0\n            return AssignResult(\n                num_gts, assigned_gt_inds, None, labels=assigned_labels)\n\n        # 2. 
compute the weighted costs\n        # classification and bboxcost.\n        if self.cls_cost.weight != 0 and cls_pred is not None:\n            cls_cost = self.cls_cost(cls_pred, gt_labels)\n        else:\n            cls_cost = 0\n        if self.mask_cost.weight != 0:\n            reg_cost = self.mask_cost(bbox_pred, gt_bboxes)\n        else:\n            reg_cost = 0\n        if self.dice_cost.weight != 0:\n            dice_cost = self.dice_cost(bbox_pred, gt_bboxes)\n        else:\n            dice_cost = 0\n        if self.boundary_cost is not None and self.boundary_cost.weight != 0:\n            b_cost = self.boundary_cost(bbox_pred, gt_bboxes)\n        else:\n            b_cost = 0\n        cost = cls_cost + reg_cost + dice_cost + b_cost\n\n        # 3. do Hungarian matching on CPU using linear_sum_assignment\n        cost = cost.detach().cpu()\n        if linear_sum_assignment is None:\n            raise ImportError('Please run \"pip install scipy\" '\n                              'to install scipy first.')\n        if self.topk == 1:\n            matched_row_inds, matched_col_inds = linear_sum_assignment(cost)\n        else:\n            topk_matched_row_inds = []\n            topk_matched_col_inds = []\n            for i in range(self.topk):\n                matched_row_inds, matched_col_inds = linear_sum_assignment(\n                    cost)\n                topk_matched_row_inds.append(matched_row_inds)\n                topk_matched_col_inds.append(matched_col_inds)\n                cost[matched_row_inds] = 1e10\n            matched_row_inds = np.concatenate(topk_matched_row_inds)\n            matched_col_inds = np.concatenate(topk_matched_col_inds)\n\n        matched_row_inds = torch.from_numpy(matched_row_inds).to(\n            bbox_pred.device)\n        matched_col_inds = torch.from_numpy(matched_col_inds).to(\n            bbox_pred.device)\n\n        # 4. 
assign backgrounds and foregrounds\n        # assign all indices to backgrounds first\n        assigned_gt_inds[:] = 0\n        # assign foregrounds based on matching results\n        assigned_gt_inds[matched_row_inds] = matched_col_inds + 1\n        assigned_labels[matched_row_inds] = gt_labels[matched_col_inds]\n        return AssignResult(\n            num_gts, assigned_gt_inds, None, labels=assigned_labels)\n"
  },
  {
    "path": "knet/video/mask_pseudo_sampler.py",
    "content": "import torch\n\nfrom mmdet.core.bbox import BaseSampler, SamplingResult\nfrom mmdet.core.bbox.builder import BBOX_SAMPLERS\n\n\nclass MaskSamplingResult(SamplingResult):\n    \"\"\"Mask sampling result.\n\n    Mirrors :class:`SamplingResult`, but stores positive and negative masks\n    (`pos_masks`, `neg_masks`) and the matched ground truth masks\n    (`pos_gt_masks`) instead of boxes.\n    \"\"\"\n\n    def __init__(self, pos_inds, neg_inds, masks, gt_masks, assign_result,\n                 gt_flags):\n        self.pos_inds = pos_inds\n        self.neg_inds = neg_inds\n        self.pos_masks = masks[pos_inds]\n        self.neg_masks = masks[neg_inds]\n        self.pos_is_gt = gt_flags[pos_inds]\n\n        self.num_gts = gt_masks.shape[0]\n        self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1\n\n        if gt_masks.numel() == 0:\n            # no ground truth: avoid an index error on the empty gt_masks\n            assert self.pos_assigned_gt_inds.numel() == 0\n            self.pos_gt_masks = torch.empty_like(gt_masks)\n        else:\n            self.pos_gt_masks = gt_masks[self.pos_assigned_gt_inds, :]\n\n        if assign_result.labels is not None:\n            self.pos_gt_labels = assign_result.labels[pos_inds]\n        else:\n            self.pos_gt_labels = None\n\n    @property\n    def masks(self):\n        \"\"\"torch.Tensor: concatenated positive and negative masks\"\"\"\n        return torch.cat([self.pos_masks, self.neg_masks])\n\n    def __nice__(self):\n        data = self.info.copy()\n        data['pos_masks'] = data.pop('pos_masks').shape\n        data['neg_masks'] = data.pop('neg_masks').shape\n        parts = [f\"'{k}': {v!r}\" for k, v in sorted(data.items())]\n        body = '    ' + ',\\n    '.join(parts)\n        return '{\\n' + body + '\\n}'\n\n    @property\n    def info(self):\n        \"\"\"Returns a dictionary of info about the object.\"\"\"\n        return {\n            'pos_inds': self.pos_inds,\n            'neg_inds': self.neg_inds,\n            'pos_masks': self.pos_masks,\n            'neg_masks': self.neg_masks,\n            'pos_is_gt': self.pos_is_gt,\n            'num_gts': self.num_gts,\n            'pos_assigned_gt_inds': self.pos_assigned_gt_inds,\n        }"
  },
  {
    "path": "knet/video/qdtrack/builder.py",
    "content": "from mmcv.utils import Registry\nfrom mmcv.cnn import build_model_from_cfg as build\n\nTRACKERS = Registry('tracker')\n\n\ndef build_tracker(cfg):\n    \"\"\"Build tracker.\"\"\"\n    return build(cfg, TRACKERS)\n"
  },
  {
    "path": "knet/video/qdtrack/losses/__init__.py",
    "content": "from .l2_loss import L2Loss\nfrom .multipos_cross_entropy_loss import MultiPosCrossEntropyLoss\n\n__all__ = ['L2Loss', 'MultiPosCrossEntropyLoss']"
  },
  {
    "path": "knet/video/qdtrack/losses/l2_loss.py",
    "content": "import numpy as np\nimport torch\nimport torch.nn as nn\nfrom mmdet.models import LOSSES, weighted_loss\n\n\n@weighted_loss\ndef l2_loss(pred, target):\n    \"\"\"L2 loss.\n\n    Args:\n        pred (torch.Tensor): The prediction.\n        target (torch.Tensor): The learning target of the prediction.\n\n    Returns:\n        torch.Tensor: Calculated loss\n    \"\"\"\n    assert pred.size() == target.size() and target.numel() > 0\n    loss = torch.abs(pred - target)**2\n    return loss\n\n\n@LOSSES.register_module()\nclass L2Loss(nn.Module):\n    \"\"\"L2 loss.\n\n    Args:\n        neg_pos_ub (int, optional): Upper bound of the negative-to-positive\n            sample ratio. When it is exceeded, negatives are subsampled.\n            Non-positive values disable subsampling. Default -1.\n        pos_margin (float, optional): Margin subtracted from the predictions\n            of positive samples. Ignored if not positive. Default -1.\n        neg_margin (float, optional): Margin subtracted from the predictions\n            of negative samples. Ignored if not positive. Default -1.\n        hard_mining (bool, optional): When subsampling negatives, keep those\n            with the largest loss instead of a random subset. Default False.\n        reduction (str, optional): The method to reduce the loss.\n            Options are \"none\", \"mean\" and \"sum\".\n        loss_weight (float, optional): The weight of loss.\n    \"\"\"\n\n    def __init__(self,\n                 neg_pos_ub=-1,\n                 pos_margin=-1,\n                 neg_margin=-1,\n                 hard_mining=False,\n                 reduction='mean',\n                 loss_weight=1.0):\n        super(L2Loss, self).__init__()\n        self.neg_pos_ub = neg_pos_ub\n        self.pos_margin = pos_margin\n        self.neg_margin = neg_margin\n        self.hard_mining = hard_mining\n        self.reduction = reduction\n        self.loss_weight = loss_weight\n\n    def forward(self,\n                pred,\n                target,\n                weight=None,\n                avg_factor=None,\n                reduction_override=None):\n        \"\"\"Forward function.\n\n        Args:\n            pred (torch.Tensor): The prediction.\n            target (torch.Tensor): The learning target of the prediction.\n            weight (torch.Tensor, optional): The weight of loss for each\n                prediction. Defaults to None.\n            avg_factor (int, optional): Average factor that is used to average\n                the loss. Defaults to None.\n            reduction_override (str, optional): The reduction method used to\n                override the original reduction method of the loss.\n                Defaults to None.\n        \"\"\"\n        assert reduction_override in (None, 'none', 'mean', 'sum')\n        reduction = (\n            reduction_override if reduction_override else self.reduction)\n        pred, weight, avg_factor = self.update_weight(pred, target, weight, avg_factor)\n        loss = self.loss_weight * l2_loss(\n            pred, target, weight, reduction=reduction, avg_factor=avg_factor)\n        return loss\n\n    def update_weight(self, pred, target, weight, avg_factor):\n        if weight is None:\n            weight = target.new_ones(target.size())\n        # mark ignored entries (weight <= 0) with -1; this modifies target in place\n        invalid_inds = weight <= 0\n        target[invalid_inds] = -1\n        pos_inds = target == 1\n        neg_inds = target == 0\n\n        if self.pos_margin > 0:\n            pred[pos_inds] -= self.pos_margin\n        if self.neg_margin > 0:\n            pred[neg_inds] -= self.neg_margin\n        pred = torch.clamp(pred, min=0, max=1)\n\n        num_pos = int((target == 1).sum())\n        num_neg = int((target == 0).sum())\n        if self.neg_pos_ub > 0 and num_neg / (num_pos + 1) > self.neg_pos_ub:\n            # too many negatives: keep at most num_pos * neg_pos_ub of them\n            num_neg = num_pos * self.neg_pos_ub\n            neg_idx = torch.nonzero(target == 0, as_tuple=False)\n\n            if self.hard_mining:\n                costs = l2_loss(\n                    pred, target, reduction='none')[neg_idx[:, 0],\n                                                    neg_idx[:, 1]].detach()\n                neg_idx = neg_idx[costs.topk(num_neg)[1], :]\n            else:\n                neg_idx = self.random_choice(neg_idx, num_neg)\n\n            new_neg_inds = neg_inds.new_zeros(neg_inds.size()).bool()\n            new_neg_inds[neg_idx[:, 0], neg_idx[:, 1]] = True\n\n            # zero the weights of the negatives that were not kept\n            invalid_neg_inds = torch.logical_xor(neg_inds, new_neg_inds)\n            weight[invalid_neg_inds] = 0\n\n        avg_factor = (weight > 0).sum()\n        return pred, weight, avg_factor\n\n    @staticmethod\n    def random_choice(gallery, num):\n        \"\"\"Randomly select `num` elements from the gallery.\n\n        It seems that PyTorch's implementation is slower than numpy so we use\n        numpy to randperm the indices.\n        \"\"\"\n        assert len(gallery) >= num\n        if isinstance(gallery, list):\n            gallery = np.array(gallery)\n        cands = np.arange(len(gallery))\n        np.random.shuffle(cands)\n        rand_inds = cands[:num]\n        if not isinstance(gallery, np.ndarray):\n            rand_inds = torch.from_numpy(rand_inds).long().to(gallery.device)\n        return gallery[rand_inds]\n"
  },
  {
    "path": "knet/video/qdtrack/losses/multipos_cross_entropy_loss.py",
    "content": "import torch\nimport torch.nn as nn\nfrom mmdet.models import LOSSES, weight_reduce_loss\n\n\ndef multi_pos_cross_entropy(pred,\n                            label,\n                            weight=None,\n                            reduction='mean',\n                            avg_factor=None):\n    \"\"\"Multi-positive cross entropy.\n\n    For each row of `pred`, with positives P (`label == 1`) and negatives N\n    (`label == 0`), the loss is\n    log(1 + sum_{i in P} sum_{j in N} exp(pred[j] - pred[i])).\n    \"\"\"\n    # element-wise losses (direct form, kept for reference)\n    # pos_inds = (label == 1).float()\n    # neg_inds = (label == 0).float()\n    # exp_pos = (torch.exp(-1 * pred) * pos_inds).sum(dim=1)\n    # exp_neg = (torch.exp(pred.clamp(max=80)) * neg_inds).sum(dim=1)\n    # loss = torch.log(1 + exp_pos * exp_neg)\n\n    # a more numerically stable implementation.\n    pos_inds = (label == 1)\n    neg_inds = (label == 0)\n    pred_pos = pred * pos_inds.float()\n    pred_neg = pred * neg_inds.float()\n    # set negatives in pred_pos to +inf and positives in pred_neg to -inf, so\n    # those pairs give -inf below and contribute exp(-inf) = 0\n    pred_pos[neg_inds] = pred_pos[neg_inds] + float('inf')\n    pred_neg[pos_inds] = pred_neg[pos_inds] + float('-inf')\n\n    # pairwise differences pred_neg[j] - pred_pos[i] for every (i, j) per row\n    _pos_expand = torch.repeat_interleave(pred_pos, pred.shape[1], dim=1)\n    _neg_expand = pred_neg.repeat(1, pred.shape[1])\n\n    # the padded zero column supplies the \"1 +\" term inside the log\n    x = torch.nn.functional.pad((_neg_expand - _pos_expand), (0, 1), \"constant\", 0)\n    loss = torch.logsumexp(x, dim=1)\n\n    # apply weights and do the reduction\n    if weight is not None:\n        weight = weight.float()\n    loss = weight_reduce_loss(\n        loss, weight=weight, reduction=reduction, avg_factor=avg_factor)\n\n    return loss\n\n\n@LOSSES.register_module()\nclass MultiPosCrossEntropyLoss(nn.Module):\n    \"\"\"Loss module wrapping :func:`multi_pos_cross_entropy`.\"\"\"\n\n    def __init__(self, reduction='mean', loss_weight=1.0):\n        super(MultiPosCrossEntropyLoss, self).__init__()\n        self.reduction = reduction\n        self.loss_weight = loss_weight\n\n    def forward(self,\n                cls_score,\n                label,\n                weight=None,\n                avg_factor=None,\n                reduction_override=None,\n                **kwargs):\n        assert cls_score.size() == label.size()\n        assert reduction_override in (None, 'none', 'mean', 'sum')\n        reduction = (\n            reduction_override if reduction_override else self.reduction)\n        loss_cls = self.loss_weight * multi_pos_cross_entropy(\n            cls_score,\n            label,\n            weight,\n            reduction=reduction,\n            avg_factor=avg_factor,\n            **kwargs)\n        return loss_cls\n"
  },
  {
    "path": "knet/video/qdtrack/track/__init__.py",
    "content": "from .similarity import cal_similarity\nfrom .transforms import track2result, restore_result\n\n__all__ = ['cal_similarity', 'track2result', 'restore_result']\n"
  },
  {
    "path": "knet/video/qdtrack/track/similarity.py",
    "content": "import torch\nimport torch.nn.functional as F\n\n\ndef cal_similarity(key_embeds,\n                   ref_embeds,\n                   method='dot_product',\n                   temperature=-1):\n    assert method in ['dot_product', 'cosine']\n\n    if key_embeds.size(0) == 0 or ref_embeds.size(0) == 0:\n        return torch.zeros((key_embeds.size(0), ref_embeds.size(0)),\n                           device=key_embeds.device)\n\n    if method == 'cosine':\n        key_embeds = F.normalize(key_embeds, p=2, dim=1)\n        ref_embeds = F.normalize(ref_embeds, p=2, dim=1)\n        return torch.mm(key_embeds, ref_embeds.t())\n    elif method == 'dot_product':\n        if temperature > 0:\n            dists = cal_similarity(key_embeds, ref_embeds, method='cosine')\n            dists /= temperature\n            return dists\n        else:\n            return torch.mm(key_embeds, ref_embeds.t())\n"
  },
  {
    "path": "knet/video/qdtrack/track/transforms.py",
    "content": "import numpy as np\nimport torch\n\n\ndef track2result(bboxes, labels, ids, num_classes):\n    valid_inds = ids > -1\n    bboxes = bboxes[valid_inds]\n    labels = labels[valid_inds]\n    ids = ids[valid_inds]\n\n    if bboxes.shape[0] == 0:\n        return [np.zeros((0, 6), dtype=np.float32) for i in range(num_classes)]\n    else:\n        if isinstance(bboxes, torch.Tensor):\n            bboxes = bboxes.cpu().numpy()\n            labels = labels.cpu().numpy()\n            ids = ids.cpu().numpy()\n        return [\n            np.concatenate((ids[labels == i, None], bboxes[labels == i, :]),\n                           axis=1) for i in range(num_classes)\n        ]\n\n\ndef restore_result(result, return_ids=False):\n    labels = []\n    for i, bbox in enumerate(result):\n        labels.extend([i] * bbox.shape[0])\n    bboxes = np.concatenate(result, axis=0).astype(np.float32)\n    labels = np.array(labels, dtype=np.int64)\n    if return_ids:\n        ids = bboxes[:, 0].astype(np.int64)\n        bboxes = bboxes[:, 1:]\n        return bboxes, labels, ids\n    else:\n        return bboxes, labels\n"
  },
  {
    "path": "knet/video/qdtrack/trackers/__init__.py",
    "content": "from .quasi_dense_embed_tracker import QuasiDenseEmbedTracker\nfrom .tao_tracker import TaoTracker\n\n__all__ = ['QuasiDenseEmbedTracker', 'TaoTracker']"
  },
  {
    "path": "knet/video/qdtrack/trackers/quasi_dense_embed_tracker.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom mmdet.core import bbox_overlaps\n\nfrom ..builder import TRACKERS\n\n\n@TRACKERS.register_module()\nclass QuasiDenseEmbedTracker(object):\n\n    def __init__(self,\n                 init_score_thr=0.8,\n                 obj_score_thr=0.5,\n                 match_score_thr=0.5,\n                 memo_tracklet_frames=10,\n                 memo_backdrop_frames=1,\n                 memo_momentum=0.8,\n                 nms_conf_thr=0.5,\n                 nms_backdrop_iou_thr=0.3,\n                 nms_class_iou_thr=0.7,\n                 with_cats=True,\n                 match_metric='bisoftmax'):\n        assert 0 <= memo_momentum <= 1.0\n        assert memo_tracklet_frames >= 0\n        assert memo_backdrop_frames >= 0\n        self.init_score_thr = init_score_thr\n        self.obj_score_thr = obj_score_thr\n        self.match_score_thr = match_score_thr\n        self.memo_tracklet_frames = memo_tracklet_frames\n        self.memo_backdrop_frames = memo_backdrop_frames\n        self.memo_momentum = memo_momentum\n        self.nms_conf_thr = nms_conf_thr\n        self.nms_backdrop_iou_thr = nms_backdrop_iou_thr\n        self.nms_class_iou_thr = nms_class_iou_thr\n        self.with_cats = with_cats\n        assert match_metric in ['bisoftmax', 'softmax', 'cosine']\n        self.match_metric = match_metric\n\n        self.num_tracklets = 0\n        self.tracklets = dict()\n        self.backdrops = []\n\n    @property\n    def empty(self):\n        return False if self.tracklets else True\n\n    def update_memo(self, ids, bboxes, embeds, labels, frame_id):\n        tracklet_inds = ids > -1\n\n        # update memo\n        for id, bbox, embed, label in zip(ids[tracklet_inds],\n                                          bboxes[tracklet_inds],\n                                          embeds[tracklet_inds],\n                                          labels[tracklet_inds]):\n            id = int(id)\n         
   if id in self.tracklets.keys():\n                velocity = (bbox - self.tracklets[id]['bbox']) / (\n                    frame_id - self.tracklets[id]['last_frame'])\n                self.tracklets[id]['bbox'] = bbox\n                self.tracklets[id]['embed'] = (\n                    1 - self.memo_momentum\n                ) * self.tracklets[id]['embed'] + self.memo_momentum * embed\n                self.tracklets[id]['last_frame'] = frame_id\n                self.tracklets[id]['label'] = label\n                self.tracklets[id]['velocity'] = (\n                    self.tracklets[id]['velocity'] *\n                    self.tracklets[id]['acc_frame'] + velocity) / (\n                        self.tracklets[id]['acc_frame'] + 1)\n                self.tracklets[id]['acc_frame'] += 1\n            else:\n                self.tracklets[id] = dict(\n                    bbox=bbox,\n                    embed=embed,\n                    label=label,\n                    last_frame=frame_id,\n                    velocity=torch.zeros_like(bbox),\n                    acc_frame=0)\n\n        backdrop_inds = torch.nonzero(ids == -1, as_tuple=False).squeeze(1)\n        ious = bbox_overlaps(bboxes[backdrop_inds, :-1], bboxes[:, :-1])\n        for i, ind in enumerate(backdrop_inds):\n            if (ious[i, :ind] > self.nms_backdrop_iou_thr).any():\n                backdrop_inds[i] = -1\n        backdrop_inds = backdrop_inds[backdrop_inds > -1]\n\n        self.backdrops.insert(\n            0,\n            dict(\n                bboxes=bboxes[backdrop_inds],\n                embeds=embeds[backdrop_inds],\n                labels=labels[backdrop_inds]))\n\n        # pop memo\n        invalid_ids = []\n        for k, v in self.tracklets.items():\n            if frame_id - v['last_frame'] >= self.memo_tracklet_frames:\n                invalid_ids.append(k)\n        for invalid_id in invalid_ids:\n            self.tracklets.pop(invalid_id)\n\n        if len(self.backdrops) > 
self.memo_backdrop_frames:\n            self.backdrops.pop()\n\n    @property\n    def memo(self):\n        memo_embeds = []\n        memo_ids = []\n        memo_bboxes = []\n        memo_labels = []\n        memo_vs = []\n        for k, v in self.tracklets.items():\n            memo_bboxes.append(v['bbox'][None, :])\n            memo_embeds.append(v['embed'][None, :])\n            memo_ids.append(k)\n            memo_labels.append(v['label'].view(1, 1))\n            memo_vs.append(v['velocity'][None, :])\n        memo_ids = torch.tensor(memo_ids, dtype=torch.long).view(1, -1)\n\n        for backdrop in self.backdrops:\n            backdrop_ids = torch.full((1, backdrop['embeds'].size(0)),\n                                      -1,\n                                      dtype=torch.long)\n            backdrop_vs = torch.zeros_like(backdrop['bboxes'])\n            memo_bboxes.append(backdrop['bboxes'])\n            memo_embeds.append(backdrop['embeds'])\n            memo_ids = torch.cat([memo_ids, backdrop_ids], dim=1)\n            memo_labels.append(backdrop['labels'][:, None])\n            memo_vs.append(backdrop_vs)\n\n        memo_bboxes = torch.cat(memo_bboxes, dim=0)\n        memo_embeds = torch.cat(memo_embeds, dim=0)\n        memo_labels = torch.cat(memo_labels, dim=0).squeeze(1)\n        memo_vs = torch.cat(memo_vs, dim=0)\n        return memo_bboxes, memo_labels, memo_embeds, memo_ids.squeeze(\n            0), memo_vs\n\n    def match(self, bboxes, labels, track_feats, frame_id, asso_tau=-1):\n\n        _, inds = bboxes[:, -1].sort(descending=True)\n        bboxes = bboxes[inds, :]\n        labels = labels[inds]\n        embeds = track_feats[inds, :]\n\n        # hack: standard NMS is not applied here; instead, duplicates are removed\n        # class-agnostically, using nms_backdrop_iou_thr for potential backdrops\n        # (score < obj_score_thr) and nms_class_iou_thr for the other boxes\n        valids = bboxes.new_ones((bboxes.size(0)))\n        ious = bbox_overlaps(bboxes[:, :-1], bboxes[:, :-1])\n        for i in range(1, bboxes.size(0)):\n            thr = self.nms_backdrop_iou_thr if bboxes[\n                i, -1] < self.obj_score_thr else self.nms_class_iou_thr\n            if (ious[i, :i] > thr).any():\n                valids[i] = 0\n        valids = valids == 1\n        bboxes = bboxes[valids, :]\n        labels = labels[valids]\n        embeds = embeds[valids, :]\n\n        # init ids container\n        ids = torch.full((bboxes.size(0), ), -1, dtype=torch.long)\n\n        # match if buffer is not empty\n        if bboxes.size(0) > 0 and not self.empty:\n            (memo_bboxes, memo_labels, memo_embeds, memo_ids,\n             memo_vs) = self.memo\n\n            if self.match_metric == 'bisoftmax':\n                feats = torch.mm(embeds, memo_embeds.t())\n                d2t_scores = feats.softmax(dim=1)\n                t2d_scores = feats.softmax(dim=0)\n                scores = (d2t_scores + t2d_scores) / 2\n            elif self.match_metric == 'softmax':\n                feats = torch.mm(embeds, memo_embeds.t())\n                scores = feats.softmax(dim=1)\n            elif self.match_metric == 'cosine':\n                scores = torch.mm(\n                    F.normalize(embeds, p=2, dim=1),\n                    F.normalize(memo_embeds, p=2, dim=1).t())\n            else:\n                raise NotImplementedError\n\n            if self.with_cats:\n                cat_same = labels.view(-1, 1) == memo_labels.view(1, -1)\n                scores *= cat_same.float().to(scores.device)\n\n            for i in range(bboxes.size(0)):\n                conf, memo_ind = torch.max(scores[i, :], dim=0)\n                id = memo_ids[memo_ind]\n                if conf > self.match_score_thr:\n                    if id > -1:\n                        if bboxes[i, -1] > self.obj_score_thr:\n                            ids[i] = id\n                            scores[:i, memo_ind] = 0\n                            scores[i + 1:, memo_ind] = 0\n                        else:\n                            if conf > self.nms_conf_thr:\n                                ids[i] = -2\n        new_inds = (ids == -1) & (bboxes[:, 4] > self.init_score_thr).cpu()\n        num_news = new_inds.sum()\n        ids[new_inds] = torch.arange(\n            self.num_tracklets,\n            self.num_tracklets + num_news,\n            dtype=torch.long)\n        self.num_tracklets += num_news\n\n        self.update_memo(ids, bboxes, embeds, labels, frame_id)\n\n        return bboxes, labels, ids\n"
  },
  {
    "path": "knet/video/qdtrack/trackers/tao_tracker.py",
    "content": "import os\nimport random\nfrom collections import defaultdict\n\nimport cv2\nimport mmcv\nimport numpy as np\nimport seaborn as sns\nimport torch\nfrom mmcv.image import imread, imwrite\nfrom mmcv.visualization import color_val, imshow\nfrom mmdet.core import bbox_overlaps\n\nfrom knet.video.qdtrack.track.similarity import cal_similarity\nfrom ..builder import TRACKERS\n\n\n@TRACKERS.register_module()\nclass TaoTracker(object):\n\n    def __init__(self,\n                 init_score_thr=0.0001,\n                 obj_score_thr=0.0001,\n                 match_score_thr=0.5,\n                 memo_frames=10,\n                 momentum_embed=0.8,\n                 momentum_obj_score=0.5,\n                 obj_score_diff_thr=1.0,\n                 distractor_nms_thr=0.3,\n                 distractor_score_thr=0.5,\n                 match_metric='bisoftmax',\n                 match_with_cosine=True):\n        self.init_score_thr = init_score_thr\n        self.obj_score_thr = obj_score_thr\n        self.match_score_thr = match_score_thr\n\n        self.memo_frames = memo_frames\n        self.momentum_embed = momentum_embed\n        self.momentum_obj_score = momentum_obj_score\n        self.obj_score_diff_thr = obj_score_diff_thr\n        self.distractor_nms_thr = distractor_nms_thr\n        self.distractor_score_thr = distractor_score_thr\n        assert match_metric in ['bisoftmax', 'cosine']\n        self.match_metric = match_metric\n        self.match_with_cosine = match_with_cosine\n\n        self.reset()\n\n    def reset(self):\n        self.num_tracklets = 0\n        self.tracklets = dict()\n        # for analysis\n        self.pred_tracks = defaultdict(lambda: defaultdict(list))\n        self.gt_tracks = defaultdict(lambda: defaultdict(list))\n\n    @property\n    def valid_ids(self):\n        valid_ids = []\n        for k, v in self.gt_tracks.items():\n            valid_ids.extend(v['ids'])\n        return list(set(valid_ids))\n\n    @property\n    
def empty(self):\n        return False if self.tracklets else True\n\n    def update_memo(self, ids, bboxes, labels, embeds, frame_id):\n        tracklet_inds = ids > -1\n\n        # update memo\n        for id, bbox, embed, label in zip(ids[tracklet_inds],\n                                          bboxes[tracklet_inds],\n                                          embeds[tracklet_inds],\n                                          labels[tracklet_inds]):\n            id = int(id)\n            if id in self.tracklets:\n                self.tracklets[id]['bboxes'].append(bbox)\n                self.tracklets[id]['labels'].append(label)\n                self.tracklets[id]['embeds'] = (\n                                                       1 - self.momentum_embed\n                                               ) * self.tracklets[id]['embeds'] + self.momentum_embed * embed\n                self.tracklets[id]['frame_ids'].append(frame_id)\n            else:\n                self.tracklets[id] = dict(\n                    bboxes=[bbox],\n                    labels=[label],\n                    embeds=embed,\n                    frame_ids=[frame_id])\n\n        # pop memo\n        invalid_ids = []\n        for k, v in self.tracklets.items():\n            if frame_id - v['frame_ids'][-1] >= self.memo_frames:\n                invalid_ids.append(k)\n        for invalid_id in invalid_ids:\n            self.tracklets.pop(invalid_id)\n\n    @property\n    def memo(self):\n        memo_ids = []\n        memo_bboxes = []\n        memo_labels = []\n        memo_embeds = []\n        for k, v in self.tracklets.items():\n            memo_ids.append(k)\n            memo_bboxes.append(v['bboxes'][-1][None, :])\n            memo_labels.append(v['labels'][-1].view(1, 1))\n            memo_embeds.append(v['embeds'][None, :])\n        memo_ids = torch.tensor(memo_ids, dtype=torch.long).view(1, -1)\n\n        memo_bboxes = torch.cat(memo_bboxes, dim=0)\n        memo_embeds = 
torch.cat(memo_embeds, dim=0)\n        memo_labels = torch.cat(memo_labels, dim=0).squeeze(1)\n        return memo_bboxes, memo_labels, memo_embeds, memo_ids.squeeze(0)\n\n    def init_tracklets(self, ids, obj_scores):\n        new_objs = (ids == -1) & (obj_scores > self.init_score_thr).cpu()\n        num_new_objs = new_objs.sum()\n        ids[new_objs] = torch.arange(\n            self.num_tracklets,\n            self.num_tracklets + num_new_objs,\n            dtype=torch.long)\n        self.num_tracklets += num_new_objs\n        return ids\n\n    def match(self,\n              bboxes,\n              labels,\n              track_feats,\n              frame_id,\n              temperature=-1,\n              **kwargs):\n        if track_feats is None:\n            ids = torch.full((bboxes.size(0), ), -1, dtype=torch.long)\n            return bboxes, labels, ids\n\n        # all objects are valid initially\n        valid_inds = labels > -1\n        # class-aware NMS: drop low-score (distractor) boxes that overlap\n        # an earlier box of the same class\n        low_inds = torch.nonzero(\n            bboxes[:, -1] < self.distractor_score_thr,\n            as_tuple=False).squeeze(1)\n        cat_same = labels[low_inds].view(-1, 1) == labels.view(1, -1)\n        ious = bbox_overlaps(bboxes[low_inds, :-1], bboxes[:, :-1])\n        ious *= cat_same.to(ious.device)\n        for i, ind in enumerate(low_inds):\n            if (ious[i, :ind] > self.distractor_nms_thr).any():\n                valid_inds[ind] = False\n        bboxes = bboxes[valid_inds]\n        labels = labels[valid_inds]\n        embeds = track_feats[valid_inds]\n\n        # match if buffer is not empty\n        if bboxes.size(0) > 0 and not self.empty:\n            memo_bboxes, memo_labels, memo_embeds, memo_ids = self.memo\n\n            if self.match_metric == 'bisoftmax':\n                sims = cal_similarity(\n                    embeds,\n                    memo_embeds,\n                    method='dot_product',\n                    temperature=temperature)\n                cat_same = labels.view(-1, 1) == memo_labels.view(1, -1)\n                exps = torch.exp(sims) * cat_same.to(sims.device)\n                d2t_scores = exps / (exps.sum(dim=1).view(-1, 1) + 1e-6)\n                t2d_scores = exps / (exps.sum(dim=0).view(1, -1) + 1e-6)\n                cos_scores = cal_similarity(\n                    embeds, memo_embeds, method='cosine')\n                cos_scores *= cat_same.to(cos_scores.device)\n                scores = (d2t_scores + t2d_scores) / 2\n                if self.match_with_cosine:\n                    scores = (scores + cos_scores) / 2\n            elif self.match_metric == 'cosine':\n                cos_scores = cal_similarity(\n                    embeds, memo_embeds, method='cosine')\n                cat_same = labels.view(-1, 1) == memo_labels.view(1, -1)\n                scores = cos_scores * cat_same.float().to(cos_scores.device)\n            else:\n                raise NotImplementedError()\n            if 'metas' in kwargs:\n                raw_scores = scores.clone()\n\n            obj_score_diffs = torch.abs(\n                bboxes[:, -1].view(-1, 1).expand_as(scores) -\n                memo_bboxes[:, -1].view(1, -1).expand_as(scores))\n\n            num_objs = bboxes.size(0)\n            ids = torch.full((num_objs, ), -1, dtype=torch.long)\n            for i in range(num_objs):\n                if bboxes[i, -1] < self.obj_score_thr:\n                    continue\n                conf, memo_ind = torch.max(scores[i, :], dim=0)\n                obj_score_diff = obj_score_diffs[i, memo_ind]\n                if (conf > self.match_score_thr) and (obj_score_diff <\n                                                      self.obj_score_diff_thr):\n                    ids[i] = memo_ids[memo_ind]\n                    scores[:i, memo_ind] = 0\n                    scores[i + 1:, memo_ind] = 0\n                    m = self.momentum_obj_score\n                    bboxes[i, -1] = m * bboxes[i, -1] + (\n                            1 - m) * 
memo_bboxes[memo_ind, -1]\n        else:\n            ids = torch.full((bboxes.size(0), ), -1, dtype=torch.long)\n        # init tracklets\n        ids = self.init_tracklets(ids, bboxes[:, -1])\n        self.update_memo(ids, bboxes, labels, embeds, frame_id)\n\n        # ----------------\n        if 'metas' in kwargs and kwargs['metas'].analyze:\n            metas = kwargs['metas']\n            gt_bboxes, gt_labels, gt_ids = [\n                metas['bboxes'], metas['labels'], metas['instance_ids']\n            ]\n            gt_bboxes = torch.cat(\n                (gt_bboxes, torch.zeros(gt_bboxes.size(0), 1)), dim=1)\n\n            if bboxes.size(0) == 0 or gt_bboxes.size(0) == 0:\n                return bboxes, labels, ids\n\n            fns = torch.ones(gt_bboxes.size(0), dtype=torch.long)\n            fps = torch.ones(bboxes.size(0), dtype=torch.long)\n            sw_fps = torch.zeros(bboxes.size(0), dtype=torch.long)\n            idsw = torch.zeros(bboxes.size(0), dtype=torch.long)\n\n            ious = bbox_overlaps(bboxes[:, :4], gt_bboxes[:, :4])\n            same_cat = labels.view(-1, 1) == gt_labels.view(1, -1)\n            ious *= same_cat.float().to(ious.device)\n\n            gt_inds = torch.full(ids.size(), -1, dtype=torch.long)\n            for i, bbox in enumerate(bboxes):\n                max_iou, j = ious[i].max(dim=0)\n                if max_iou > 0.5:\n                    fps[i], fns[j] = 0, 0\n                    gt_inds[i] = j\n                    ious[:, j] = -1\n\n                    gt_id = int(gt_ids[j])\n                    pred_id = int(ids[i])\n                    if len(self.gt_tracks[gt_id]['ids']) > 0:\n                        if pred_id != self.gt_tracks[gt_id]['ids'][-1]:\n                            idsw[i] = 1\n                    else:\n                        if pred_id in self.pred_tracks:\n                            idsw[i] = 1\n                    self.gt_tracks[gt_id]['scores'].append(\n                        
float(f'{bbox[-1]:.3f}'))\n                    self.gt_tracks[gt_id]['ids'].append(pred_id)\n                    self.gt_tracks[gt_id]['frame_ids'].append(\n                        metas.img_info['frame_id'])\n\n            for i, id in enumerate(ids):\n                id = int(id)\n\n                self.pred_tracks[id]['scores'].append(\n                    float(f'{bboxes[i, -1]:.3f}'))\n                if metas.img_info['frame_id'] > 0:\n                    memo_ind = torch.nonzero(\n                        memo_ids == id, as_tuple=False).squeeze(1)\n                else:\n                    memo_ind = []\n                if len(memo_ind) > 0:\n                    self.pred_tracks[id]['match_scores'].append(\n                        float(f'{raw_scores[i, memo_ind[0]]:.3f}'))\n                else:\n                    self.pred_tracks[id]['match_scores'].append(-1)\n                if gt_inds[i] == -1:\n                    self.pred_tracks[id]['ids'].append(-1)\n                else:\n                    self.pred_tracks[id]['ids'].append(int(gt_ids[gt_inds[i]]))\n                self.pred_tracks[id]['frame_ids'].append(\n                    metas.img_info['frame_id'])\n\n                if fps[i]:\n                    if id in self.valid_ids:\n                        sw_fps[i] = 1\n                    continue\n\n            fp_inds = sw_fps == 1  # red\n            fn_inds = fns == 1  # yellow\n            idsw_inds = idsw == 1  # cyan\n            tp_inds = fps == 0  # green\n            tp_inds[idsw_inds] = 0\n\n            os.makedirs(metas.out_file.rsplit('/', 1)[0], exist_ok=True)\n            img = metas.img_name\n            # black\n            if idsw_inds.any():\n                sw_ids = ids[idsw_inds]\n                memo_inds = (memo_ids.view(-1, 1) == sw_ids.view(\n                    1, -1)).sum(dim=1) > 0\n                img = imshow_tracklets(\n                    img,\n                    memo_bboxes[memo_inds].numpy(),\n                  
  memo_labels[memo_inds].numpy(),\n                    memo_ids[memo_inds].numpy(),\n                    color='magenta',\n                    show=False)\n            img = imshow_tracklets(\n                img,\n                bboxes[tp_inds].numpy(),\n                labels[tp_inds].numpy(),\n                ids[tp_inds].numpy(),\n                color='green',\n                show=False)\n            img = imshow_tracklets(\n                img,\n                bboxes[fp_inds].numpy(),\n                labels[fp_inds].numpy(),\n                ids[fp_inds].numpy(),\n                color='red',\n                show=False)\n            img = imshow_tracklets(\n                img,\n                bboxes=gt_bboxes[fn_inds, :].numpy(),\n                labels=gt_labels[fn_inds].numpy(),\n                color='yellow',\n                show=False)\n            img = imshow_tracklets(\n                img,\n                bboxes[idsw_inds].numpy(),\n                labels[idsw_inds].numpy(),\n                ids[idsw_inds].numpy(),\n                color='cyan',\n                show=False,\n                out_file=metas.out_file)\n\n        return bboxes, labels, ids\n\n\ndef random_color(seed):\n    random.seed(seed)\n    colors = sns.color_palette()\n    color = random.choice(colors)\n    return color\n\n\ndef imshow_tracklets(img,\n                     bboxes,\n                     labels=None,\n                     ids=None,\n                     thickness=2,\n                     font_scale=0.4,\n                     show=False,\n                     win_name='',\n                     color=None,\n                     out_file=None):\n    assert bboxes.ndim == 2\n    assert labels.ndim == 1\n    assert bboxes.shape[0] == labels.shape[0]\n    # assert bboxes.shape[1] == 4 or bboxes.shape[1] == 5\n    if isinstance(img, str):\n        img = imread(img)\n    i = 0\n    if bboxes.shape[0] == 0:\n        if out_file is not None:\n            imwrite(img, 
out_file)\n        return img\n    if isinstance(bboxes, torch.Tensor):\n        bboxes = bboxes.numpy()\n        labels = labels.numpy()\n        ids = ids.numpy()\n    for bbox, label in zip(bboxes, labels):\n        x1, y1, x2, y2, _ = bbox.astype(np.int32)\n        if ids is not None:\n            if color is None:\n                bbox_color = random_color(ids[i])\n                bbox_color = [int(255 * _c) for _c in bbox_color][::-1]\n            else:\n                bbox_color = mmcv.color_val(color)\n            img[y1:y1 + 12, x1:x1 + 20, :] = bbox_color\n            cv2.putText(\n                img,\n                str(ids[i]), (x1, y1 + 10),\n                cv2.FONT_HERSHEY_COMPLEX,\n                font_scale,\n                color=color_val('black'))\n        else:\n            if color is None:\n                bbox_color = color_val('green')\n            else:\n                bbox_color = mmcv.color_val(color)\n\n        cv2.rectangle(img, (x1, y1), (x2, y2), bbox_color, thickness=thickness)\n\n        if bbox[-1] < 0:\n            bbox[-1] = np.nan\n        # label_text = '{:.02f}'.format(bbox[-1])\n        # img[y1 - 12:y1, x1:x1 + 30, :] = bbox_color\n        # cv2.putText(\n        #     img,\n        #     label_text, (x1, y1 - 2),\n        #     cv2.FONT_HERSHEY_COMPLEX,\n        #     font_scale,\n        #     color=color_val('black'))\n\n        i += 1\n\n    if show:\n        imshow(img, win_name)\n    if out_file is not None:\n        imwrite(img, out_file)\n\n    return img\n"
  },
  {
    "path": "knet/video/track_heads.py",
    "content": "\"\"\"\n    This file implements several tracking heads\n\"\"\"\nimport numpy as np\nimport torch\nimport torch.nn as nn\nfrom mmcv.cnn import ConvModule, normal_init\nfrom mmdet.models.builder import HEADS, build_head, build_loss, build_roi_extractor\nfrom mmdet.models.losses import accuracy\nfrom mmdet.core import multi_apply, bbox2roi\nfrom knet.video.qdtrack.track import cal_similarity\nfrom unitrack.utils.mask import mask2box, batch_mask2boxlist, bboxlist2roi\n\n\n@HEADS.register_module()\nclass QueryTrackHead(nn.Module):\n    \"\"\"Tracking head, predict tracking features and match with reference objects\n       Use dynamic option to deal with different number of objects in different\n       images. A non-match entry is added to the reference objects with all-zero\n       features. Object matched with the non-match entry is considered as a new\n       object.\n    \"\"\"\n\n    def __init__(self,\n                 num_fcs=2,\n                 in_channels=256,\n                 fc_out_channels=1024,\n                 match_coeff=None,\n                 bbox_dummy_iou=0,\n                 dynamic=True,\n                 loss_match=dict(\n                     type='CrossEntropyLoss',\n                     use_sigmoid=False,\n                     loss_weight=1.0)):\n\n        super(QueryTrackHead, self).__init__()\n        self.in_channels = in_channels\n        self.match_coeff = match_coeff\n        self.bbox_dummy_iou = bbox_dummy_iou\n        self.num_fcs = num_fcs\n        self.fcs = nn.ModuleList()\n        for i in range(num_fcs):\n            out_channels = (in_channels\n                           if i < num_fcs - 1  else fc_out_channels)\n            fc = nn.Linear(in_channels, out_channels)\n            self.fcs.append(fc)\n        self.relu = nn.ReLU(inplace=True)\n        self.debug_imgs = None\n        self.dynamic = dynamic\n        assert self.dynamic == True, \"Naive tracking embedding head must be dynamic\"\n        #### 
modification\n        self.loss_match = build_loss(loss_match)\n\n    def init_weights(self):\n        for fc in self.fcs:\n            nn.init.normal_(fc.weight, 0, 0.01)\n            nn.init.constant_(fc.bias, 0)\n\n    def compute_comp_scores(self, match_ll, bbox_scores, bbox_ious, label_delta, add_bbox_dummy=False):\n        # compute comprehensive matching score based on matching likelihood,\n        # bbox confidence, and IoUs\n        if add_bbox_dummy:\n            # new_ones keeps the dummy columns on the same device/dtype as bbox_ious\n            bbox_iou_dummy = bbox_ious.new_ones(bbox_ious.size(0), 1) * self.bbox_dummy_iou\n            bbox_ious = torch.cat((bbox_iou_dummy, bbox_ious), dim=1)\n            label_dummy = bbox_ious.new_ones(bbox_ious.size(0), 1)\n            label_delta = torch.cat((label_dummy, label_delta), dim=1)\n        if self.match_coeff is None:\n            return match_ll\n        else:\n            # match_coeff must have length 3\n            assert (len(self.match_coeff) == 3)\n            return (match_ll +\n                    self.match_coeff[0] * torch.log(bbox_scores) +\n                    self.match_coeff[1] * bbox_ious +\n                    self.match_coeff[2] * label_delta)\n\n    def forward(self, x, ref_x, x_n, ref_x_n):\n        # x and ref_x are the grouped bbox features of current and reference frame\n        # x_n are the numbers of proposals in the current images in the mini-batch,\n        # ref_x_n are the numbers of ground truth bboxes in the reference images.\n        # here we compute a correlation matrix of x and ref_x\n        # we also prepend an all-zero column (index 0) that denotes no match\n        assert len(x_n) == len(ref_x_n)  # ==> the batch size should be the same.\n        b, N, d = x.size()\n        x = x.reshape(b*N, d)\n        # ref_x may hold a different number of objects than x\n        ref_x = ref_x.reshape(-1, d)\n        for idx, fc in enumerate(self.fcs):\n            x = fc(x)\n            ref_x = fc(ref_x)\n            if idx < len(self.fcs) - 1:\n                x = self.relu(x)\n                ref_x = self.relu(ref_x)\n        n = len(x_n)\n        x_split = torch.split(x, x_n, dim=0)\n        ref_x_split = torch.split(ref_x, ref_x_n, dim=0)\n        prods = []\n        for i in range(n):\n            prod = torch.mm(x_split[i], torch.transpose(ref_x_split[i], 0, 1))\n            prods.append(prod)\n        if self.dynamic:\n            match_score = []\n            for prod in prods:\n                m = prod.size(0)\n                dummy = prod.new_zeros(m, 1)\n\n                prod_ext = torch.cat([dummy, prod], dim=1)\n                match_score.append(prod_ext)\n\n        return match_score\n\n    def loss(self,\n             match_score,\n             sampling_results):\n        losses = dict()\n        n = len(match_score)\n        x_n = [s.size(0) for s in match_score]\n        ids, id_weights = self.get_targets(sampling_results)\n        ids = torch.split(ids, x_n, dim=0)\n        id_weights = torch.split(id_weights, x_n, dim=0)\n        loss_match = 0.0\n        match_acc = 0.0\n        n_total = 0\n\n        for score, cur_ids, cur_weights in zip(match_score, ids, id_weights):\n            # squeeze(1) keeps a 1-D index tensor even when only one sample is valid\n            valid_idx = torch.nonzero(cur_weights, as_tuple=False).squeeze(1)\n            if valid_idx.numel() == 0:\n                continue\n            n_valid = valid_idx.size(0)\n            n_total += n_valid\n            loss_match_per_batch = self.loss_match(score, cur_ids, cur_weights)\n            match_acc += accuracy(\n                torch.index_select(score, 0, valid_idx),\n                torch.index_select(cur_ids, 0, valid_idx)) * n_valid\n            loss_match += loss_match_per_batch\n        if loss_match == 0.0:\n            losses['loss_match'] = ids[0].sum() * 0\n        else:\n            losses['loss_match'] = loss_match / n\n        return losses\n\n    def get_targets(self,\n                    sampling_results,\n                    concat=True,\n                    ):\n        pos_inds_list = [res.pos_inds for res in sampling_results]\n        neg_inds_list = [res.neg_inds for res in sampling_results]\n        pos_mask_list = [res.pos_masks for res in sampling_results]\n        neg_mask_list = [res.neg_masks for res in sampling_results]\n        pos_gt_pid_list = [res.pos_gt_pids for res in sampling_results]\n        ids, id_weights = multi_apply(\n            self._get_target_single,\n            pos_inds_list,\n            neg_inds_list,\n            pos_mask_list,\n            neg_mask_list,\n            pos_gt_pid_list)\n        if concat:\n            ids = torch.cat(ids, 0)\n            id_weights = torch.cat(id_weights, 0)\n\n        return ids, id_weights\n\n    def _get_target_single(self, pos_inds, neg_inds, pos_mask, neg_mask, pos_gt_pid_list):\n\n        num_pos = pos_mask.size(0)\n        num_neg = neg_mask.size(0)\n        num_samples = num_pos + num_neg\n\n        ids = pos_mask.new_zeros((num_samples,), dtype=torch.long)\n        ids_weights = pos_mask.new_zeros((num_samples,))\n        if num_pos > 0:\n            ids[pos_inds] = pos_gt_pid_list\n            ids_weights[pos_inds] = 1.0\n\n        if num_neg > 0:\n            ids_weights[neg_inds] = 0.0\n\n        return ids, ids_weights\n\n\n@HEADS.register_module()\nclass TrackHeadWithROIAlign(nn.Module):\n    \"\"\"Tracking head, predict tracking features and match with reference objects\n       Use dynamic option to deal with different number of objects in different\n       images. A non-match entry is added to the reference objects with all-zero\n       features. 
Object matched with the non-match entry is considered as a new\n       object.\n    \"\"\"\n\n    def __init__(self,\n                 num_fcs=2,\n                 in_channels=256,\n                 fc_out_channels=1024,\n                 match_coeff=None,\n                 bbox_dummy_iou=0,\n                 dynamic=True,\n                 bbox_roi_extractor=dict(\n                     type='SingleRoIExtractor',\n                     roi_layer=dict(\n                         type='RoIAlign', output_size=7, sampling_ratio=2),\n                     out_channels=256,\n                     featmap_strides=[4, 8, 16, 32]),\n                 loss_match=dict(\n                     type='CrossEntropyLoss',\n                     use_sigmoid=False,\n                     loss_weight=1.0)):\n\n        super(TrackHeadWithROIAlign, self).__init__()\n        assert bbox_roi_extractor is not None\n        self.in_channels = in_channels\n        self.match_coeff = match_coeff\n        self.bbox_dummy_iou = bbox_dummy_iou\n        self.num_fcs = num_fcs\n        self.fcs = nn.ModuleList()\n\n        for i in range(num_fcs):\n            out_channels = (in_channels\n                           if i < num_fcs - 1  else fc_out_channels)\n            fc = nn.Linear(in_channels, out_channels)\n            self.fcs.append(fc)\n        self.relu = nn.ReLU(inplace=True)\n        self.debug_imgs = None\n        self.dynamic = dynamic\n        assert self.dynamic == True, \"Naive tracking embedding head must be dynamic\"\n\n        self.bbox_roi_extractor = build_roi_extractor(\n                bbox_roi_extractor)\n        #### modification\n        self.loss_match = build_loss(loss_match)\n\n    def init_weights(self):\n        for fc in self.fcs:\n            nn.init.normal_(fc.weight, 0, 0.01)\n            nn.init.constant_(fc.bias, 0)\n\n    def compute_comp_scores(self, match_ll, bbox_scores, bbox_ious, label_delta, add_bbox_dummy=False):\n        # compute comprehensive matching score 
based on matchig likelihood,\n        # bbox confidence, and ious\n        if add_bbox_dummy:\n            bbox_iou_dummy = torch.ones(bbox_ious.size(0), 1).to(torch.cuda.current_device()) * self.bbox_dummy_iou\n            bbox_ious = torch.cat((bbox_iou_dummy, bbox_ious), dim=1)\n            label_dummy = torch.ones(bbox_ious.size(0), 1).to(torch.cuda.current_device())\n            label_delta = torch.cat((label_dummy, label_delta), dim=1)\n        if self.match_coeff is None:\n            return match_ll\n        else:\n            # match coeff needs to be length of 3\n            assert (len(self.match_coeff) == 3)\n            return (match_ll +\n                    self.match_coeff[0] * torch.log(bbox_scores) +\n                    self.match_coeff[1] * bbox_ious +\n                    self.match_coeff[2] * label_delta)\n\n    def forward(self, x, ref_x, mask_pred, ref_mask_pred, x_n, ref_x_n):\n        \"\"\"\n        Args:\n            x: backbone feature of current frame\n            ref_x: backbone feature of reference frame\n            mask_pred: mask prediction of current frame\n            ref_mask_pred: reference mask prediction\n            x_n: number of proposal\n            ref_x_n:  number of proposal in ref frame\n\n        Returns:\n\n        \"\"\"\n        # print(\"mask shape \",mask_pred.shape)\n        bbox_pred = batch_mask2boxlist(mask_pred)\n        ref_bbox_pred = batch_mask2boxlist(ref_mask_pred)\n\n        # rois = bboxlist2roi(bbox_pred)\n        # ref_rois = bboxlist2roi(ref_bbox_pred)\n\n        x = self.bbox_roi_extractor(\n            x[:self.bbox_roi_extractor.num_inputs], rois)\n\n        ref_x = self.bbox_roi_extractor(\n                ref_x[:self.bbox_roi_extractor.num_inputs], ref_rois)\n        # x and ref_x are the grouped bbox features of current and reference frame\n        # x_n are the numbers of proposals in the current images in the mini-batch,\n        # ref_x_n are the numbers of ground truth bboxes in the 
reference images.\n        # here we compute a correlation matrix of x and ref_x\n        # we also add an all-zero column to denote no matching\n\n        b, N, d = x.size()\n        x = x.reshape(b*N, d)\n        ref_x = ref_x.reshape(b*N, d)\n        for idx, fc in enumerate(self.fcs):\n            x = fc(x)\n            ref_x = fc(ref_x)\n            if idx < len(self.fcs) - 1:\n                x = self.relu(x)\n                ref_x = self.relu(ref_x)\n        n = len(x_n)\n        x_split = torch.split(x, x_n, dim=0)\n        ref_x_split = torch.split(ref_x, ref_x_n, dim=0)\n        prods = []\n        for i in range(n):\n            prod = torch.mm(x_split[i], torch.transpose(ref_x_split[i], 0, 1))\n            prods.append(prod)\n        if self.dynamic:\n            match_score = []\n            for prod in prods:\n                m = prod.size(0)\n                dummy = torch.zeros(m, 1).to(torch.cuda.current_device())\n\n                prod_ext = torch.cat([dummy, prod], dim=1)\n                match_score.append(prod_ext)\n\n        return match_score\n\n    def loss(self,\n             match_score,\n             sampling_results):\n        losses = dict()\n        n = len(match_score)\n        x_n = [s.size(0) for s in match_score]\n        ids, id_weights = self.get_targets(sampling_results)\n        ids = torch.split(ids, x_n, dim=0)\n        id_weights = torch.split(id_weights, x_n, dim=0)\n        # 0-dim accumulator; torch.zeros(0) is an empty tensor and would stay empty\n        loss_match = torch.zeros(()).to(torch.cuda.current_device())\n        match_acc = 0.\n        n_total = 0\n\n        for score, cur_ids, cur_weights in zip(match_score, ids, id_weights):\n            valid_idx = torch.nonzero(cur_weights).squeeze()\n            if len(valid_idx.size()) == 0:\n                continue\n            n_valid = valid_idx.size(0)\n            n_total += n_valid\n            loss_match += self.loss_match(\n                score, cur_ids, cur_weights)\n            match_acc += accuracy(\n                torch.index_select(score, 0, 
valid_idx),\n                torch.index_select(cur_ids, 0, valid_idx)) * n_valid\n        losses['loss_match'] = loss_match / n\n        if n_total > 0:\n            losses['match_acc'] = match_acc / n_total\n        return losses\n\n    def get_targets(self,\n                    sampling_results,\n                    concat=True,\n                    ):\n        pos_inds_list = [res.pos_inds for res in sampling_results]\n        neg_inds_list = [res.neg_inds for res in sampling_results]\n        pos_mask_list = [res.pos_masks for res in sampling_results]\n        neg_mask_list = [res.neg_masks for res in sampling_results]\n        pos_gt_pid_list = [res.pos_gt_pids for res in sampling_results]\n        ids, id_weights = multi_apply(\n            self._get_target_single,\n            pos_inds_list,\n            neg_inds_list,\n            pos_mask_list,\n            neg_mask_list,\n            pos_gt_pid_list)\n        if concat:\n            ids = torch.cat(ids, 0)\n            id_weights = torch.cat(id_weights, 0)\n\n        return ids, id_weights\n\n    def _get_target_single(self, pos_inds, neg_inds, pos_mask, neg_mask, pos_gt_pid_list):\n\n        num_pos = pos_mask.size(0)\n        num_neg = neg_mask.size(0)\n        num_samples = num_pos + num_neg\n\n        ids = pos_mask.new_zeros((num_samples,), dtype=torch.long)\n        ids_weights = pos_mask.new_zeros((num_samples,))\n        if num_pos > 0:\n            ids[pos_inds] = pos_gt_pid_list\n            ids_weights[pos_inds] = 1.0\n\n        if num_neg > 0:\n            ids_weights[neg_inds] = 0.0\n\n        return ids, ids_weights\n\n\n@HEADS.register_module()\nclass QuasiDenseMaskEmbedHead(nn.Module):\n\n    def __init__(self,\n                 num_convs=4,\n                 num_fcs=1,\n                 roi_feat_size=7,\n                 in_channels=256,\n                 conv_out_channels=256,\n                 fc_out_channels=1024,\n                 embed_channels=256,\n                 
conv_cfg=None,\n                 norm_cfg=None,\n                 softmax_temp=-1,\n                 loss_track=dict(\n                     type='MultiPosCrossEntropyLoss', loss_weight=0.25),\n                 loss_track_aux=dict(\n                     type='L2Loss',\n                     sample_ratio=3,\n                     margin=0.3,\n                     loss_weight=1.0,\n                     hard_mining=True)):\n        super(QuasiDenseMaskEmbedHead, self).__init__()\n        self.num_convs = num_convs\n        self.num_fcs = num_fcs\n        self.roi_feat_size = roi_feat_size\n        self.in_channels = in_channels\n        self.conv_out_channels = conv_out_channels\n        self.fc_out_channels = fc_out_channels\n        self.embed_channels = embed_channels\n        self.conv_cfg = conv_cfg\n        self.norm_cfg = norm_cfg\n        self.relu = nn.ReLU(inplace=True)\n        self.convs, self.fcs, last_layer_dim = self._add_conv_fc_branch(\n            self.num_convs, self.num_fcs, self.in_channels)\n        self.fc_embed = nn.Linear(last_layer_dim, embed_channels)\n\n        self.softmax_temp = softmax_temp\n        self.loss_track = build_loss(loss_track)\n        if loss_track_aux is not None:\n            self.loss_track_aux = build_loss(loss_track_aux)\n        else:\n            self.loss_track_aux = None\n\n    def _add_conv_fc_branch(self, num_convs, num_fcs, in_channels):\n        last_layer_dim = in_channels\n        # add branch specific conv layers\n        convs = nn.ModuleList()\n        if num_convs > 0:\n            for i in range(num_convs):\n                conv_in_channels = (\n                    last_layer_dim if i == 0 else self.conv_out_channels)\n                convs.append(\n                    ConvModule(\n                        conv_in_channels,\n                        self.conv_out_channels,\n                        3,\n                        padding=1,\n                        conv_cfg=self.conv_cfg,\n                        
norm_cfg=self.norm_cfg))\n            last_layer_dim = self.conv_out_channels\n        # add branch specific fc layers\n        fcs = nn.ModuleList()\n        if num_fcs > 0:\n            last_layer_dim *= (self.roi_feat_size * self.roi_feat_size)\n            for i in range(num_fcs):\n                fc_in_channels = (\n                    last_layer_dim if i == 0 else self.fc_out_channels)\n                fcs.append(nn.Linear(fc_in_channels, self.fc_out_channels))\n            last_layer_dim = self.fc_out_channels\n        return convs, fcs, last_layer_dim\n\n    def init_weights(self):\n        for m in self.fcs:\n            if isinstance(m, nn.Linear):\n                nn.init.xavier_uniform_(m.weight)\n                nn.init.constant_(m.bias, 0)\n        nn.init.normal_(self.fc_embed.weight, 0, 0.01)\n        nn.init.constant_(self.fc_embed.bias, 0)\n\n    def forward(self, x):\n        if self.num_convs > 0:\n            for i, conv in enumerate(self.convs):\n                x = conv(x)\n        x = x.view(x.size(0), -1)\n\n        if self.num_fcs > 0:\n            for i, fc in enumerate(self.fcs):\n                x = self.relu(fc(x))\n        x = self.fc_embed(x)\n        return x\n\n    def get_track_targets(self, gt_match_indices, key_sampling_results,\n                          ref_sampling_results):\n        track_targets = []\n        track_weights = []\n        for _gt_match_indices, key_res, ref_res in zip(gt_match_indices,\n                                                       key_sampling_results,\n                                                       ref_sampling_results):\n            targets = _gt_match_indices.new_zeros(\n                (key_res.pos_masks.size(0), ref_res.pos_masks.size(0)),\n                dtype=torch.int)\n            _match_indices = _gt_match_indices[key_res.pos_assigned_gt_inds]\n            pos2pos = (_match_indices.view(\n                -1, 1) == ref_res.pos_assigned_gt_inds.view(1, -1)).int()\n            
targets[:, :pos2pos.size(1)] = pos2pos\n            weights = (targets.sum(dim=1) > 0).float()\n            track_targets.append(targets)\n            track_weights.append(weights)\n        return track_targets, track_weights\n\n    def match(self, key_embeds, ref_embeds, key_sampling_results,\n              ref_sampling_results):\n\n        num_key_rois = [res.pos_masks.size(0) for res in key_sampling_results]\n        key_embeds = torch.split(key_embeds, num_key_rois)\n        num_ref_rois = [res.pos_masks.size(0) for res in ref_sampling_results]\n        ref_embeds = torch.split(ref_embeds, num_ref_rois)\n\n        dists, cos_dists = [], []\n        for key_embed, ref_embed in zip(key_embeds, ref_embeds):\n            dist = cal_similarity(\n                key_embed,\n                ref_embed,\n                method='dot_product',\n                temperature=self.softmax_temp)\n            dists.append(dist)\n            if self.loss_track_aux is not None:\n                cos_dist = cal_similarity(\n                    key_embed, ref_embed, method='cosine')\n                cos_dists.append(cos_dist)\n            else:\n                cos_dists.append(None)\n        return dists, cos_dists\n\n    def loss(self, dists, cos_dists, targets, weights):\n        losses = dict()\n\n        loss_track = 0.\n        loss_track_aux = 0.\n        for _dists, _cos_dists, _targets, _weights in zip(\n                dists, cos_dists, targets, weights):\n            loss_track += self.loss_track(\n                _dists, _targets, _weights, avg_factor=_weights.sum())\n            if self.loss_track_aux is not None:\n                loss_track_aux += self.loss_track_aux(_cos_dists, _targets)\n        losses['loss_track'] = loss_track / len(dists)\n\n        if self.loss_track_aux is not None:\n            losses['loss_track_aux'] = loss_track_aux / len(dists)\n\n        return losses\n\n    @staticmethod\n    def random_choice(gallery, num):\n        \"\"\"Random select 
some elements from the gallery.\n\n        It seems that Pytorch's implementation is slower than numpy so we use\n        numpy to randperm the indices.\n        \"\"\"\n        assert len(gallery) >= num\n        if isinstance(gallery, list):\n            gallery = np.array(gallery)\n        cands = np.arange(len(gallery))\n        np.random.shuffle(cands)\n        rand_inds = cands[:num]\n        if not isinstance(gallery, np.ndarray):\n            rand_inds = torch.from_numpy(rand_inds).long().to(gallery.device)\n        return gallery[rand_inds]\n\n\n@HEADS.register_module()\nclass QuasiDenseMaskEmbedHeadGTMask(nn.Module):\n\n    def __init__(self,\n                 num_convs=4,\n                 num_fcs=1,\n                 roi_feat_size=7,\n                 in_channels=256,\n                 conv_out_channels=256,\n                 fc_out_channels=1024,\n                 embed_channels=256,\n                 conv_cfg=None,\n                 norm_cfg=None,\n                 softmax_temp=-1,\n                 loss_track=dict(\n                     type='MultiPosCrossEntropyLoss', loss_weight=0.25),\n                 loss_track_aux=dict(\n                     type='L2Loss',\n                     sample_ratio=3,\n                     margin=0.3,\n                     loss_weight=1.0,\n                     hard_mining=True)):\n        super(QuasiDenseMaskEmbedHeadGTMask, self).__init__()\n        self.num_convs = num_convs\n        self.num_fcs = num_fcs\n        self.roi_feat_size = roi_feat_size\n        self.in_channels = in_channels\n        self.conv_out_channels = conv_out_channels\n        self.fc_out_channels = fc_out_channels\n        self.embed_channels = embed_channels\n        self.conv_cfg = conv_cfg\n        self.norm_cfg = norm_cfg\n        self.relu = nn.ReLU(inplace=True)\n        self.convs, self.fcs, last_layer_dim = self._add_conv_fc_branch(\n            self.num_convs, self.num_fcs, self.in_channels)\n        self.fc_embed = 
nn.Linear(last_layer_dim, embed_channels)\n\n        self.softmax_temp = softmax_temp\n        self.loss_track = build_loss(loss_track)\n        if loss_track_aux is not None:\n            self.loss_track_aux = build_loss(loss_track_aux)\n        else:\n            self.loss_track_aux = None\n\n    def _add_conv_fc_branch(self, num_convs, num_fcs, in_channels):\n        last_layer_dim = in_channels\n        # add branch specific conv layers\n        convs = nn.ModuleList()\n        if num_convs > 0:\n            for i in range(num_convs):\n                conv_in_channels = (\n                    last_layer_dim if i == 0 else self.conv_out_channels)\n                convs.append(\n                    ConvModule(\n                        conv_in_channels,\n                        self.conv_out_channels,\n                        3,\n                        padding=1,\n                        conv_cfg=self.conv_cfg,\n                        norm_cfg=self.norm_cfg))\n            last_layer_dim = self.conv_out_channels\n        # add branch specific fc layers\n        fcs = nn.ModuleList()\n        if num_fcs > 0:\n            last_layer_dim *= (self.roi_feat_size * self.roi_feat_size)\n            for i in range(num_fcs):\n                fc_in_channels = (\n                    last_layer_dim if i == 0 else self.fc_out_channels)\n                fcs.append(nn.Linear(fc_in_channels, self.fc_out_channels))\n            last_layer_dim = self.fc_out_channels\n        return convs, fcs, last_layer_dim\n\n    def init_weights(self):\n        for m in self.fcs:\n            if isinstance(m, nn.Linear):\n                nn.init.xavier_uniform_(m.weight)\n                nn.init.constant_(m.bias, 0)\n        nn.init.normal_(self.fc_embed.weight, 0, 0.01)\n        nn.init.constant_(self.fc_embed.bias, 0)\n\n    def forward(self, x):\n        if self.num_convs > 0:\n            for i, conv in enumerate(self.convs):\n                x = conv(x)\n        x = x.view(x.size(0), 
-1)\n\n        if self.num_fcs > 0:\n            for i, fc in enumerate(self.fcs):\n                x = self.relu(fc(x))\n        x = self.fc_embed(x)\n        return x\n\n    def get_track_targets(self, gt_match_indices, key_sampling_results,\n                          ref_sampling_results):\n        track_targets = []\n        track_weights = []\n        for _gt_match_indices, key_res, ref_res in zip(gt_match_indices,\n                                                       key_sampling_results,\n                                                       ref_sampling_results):\n            targets = _gt_match_indices.new_zeros(\n                (key_res.pos_masks.size(0), ref_res.pos_masks.size(0)),\n                dtype=torch.int)\n            _match_indices = _gt_match_indices[key_res.pos_assigned_gt_inds]\n            pos2pos = (_match_indices.view(\n                -1, 1) == ref_res.pos_assigned_gt_inds.view(1, -1)).int()\n            targets[:, :pos2pos.size(1)] = pos2pos\n            weights = (targets.sum(dim=1) > 0).float()\n            track_targets.append(targets)\n            track_weights.append(weights)\n        return track_targets, track_weights\n\n    def match(self, key_embeds, ref_embeds, key_sampling_results,\n              ref_sampling_results):\n        num_key_rois = [res.pos_masks.size(0) for res in key_sampling_results]\n        key_embeds = torch.split(key_embeds, num_key_rois)\n        num_ref_rois = [res.pos_masks.size(0) for res in ref_sampling_results]\n        ref_embeds = torch.split(ref_embeds, num_ref_rois)\n\n        dists, cos_dists = [], []\n        for key_embed, ref_embed in zip(key_embeds, ref_embeds):\n            dist = cal_similarity(\n                key_embed,\n                ref_embed,\n                method='dot_product',\n                temperature=self.softmax_temp)\n            dists.append(dist)\n            if self.loss_track_aux is not None:\n                cos_dist = cal_similarity(\n                    
key_embed, ref_embed, method='cosine')\n                cos_dists.append(cos_dist)\n            else:\n                cos_dists.append(None)\n        return dists, cos_dists\n\n    def loss(self, dists, cos_dists, targets, weights):\n        losses = dict()\n\n        loss_track = 0.\n        loss_track_aux = 0.\n        for _dists, _cos_dists, _targets, _weights in zip(\n                dists, cos_dists, targets, weights):\n            loss_track += self.loss_track(\n                _dists, _targets, _weights, avg_factor=_weights.sum())\n            if self.loss_track_aux is not None:\n                loss_track_aux += self.loss_track_aux(_cos_dists, _targets)\n        losses['loss_track'] = loss_track / len(dists)\n\n        if self.loss_track_aux is not None:\n            losses['loss_track_aux'] = loss_track_aux / len(dists)\n\n        return losses\n\n    @staticmethod\n    def random_choice(gallery, num):\n        \"\"\"Random select some elements from the gallery.\n\n        It seems that Pytorch's implementation is slower than numpy so we use\n        numpy to randperm the indices.\n        \"\"\"\n        assert len(gallery) >= num\n        if isinstance(gallery, list):\n            gallery = np.array(gallery)\n        cands = np.arange(len(gallery))\n        np.random.shuffle(cands)\n        rand_inds = cands[:num]\n        if not isinstance(gallery, np.ndarray):\n            rand_inds = torch.from_numpy(rand_inds).long().to(gallery.device)\n        return gallery[rand_inds]"
  },
  {
    "path": "knet/video/tracker.py",
    "content": "\"\"\"\nThis is a simple mask based tracker\nCopyright (c) https://github.com/xingyizhou/CenterTrack\nModified by Xiangtai Li\n\n\"\"\"\n# coding: utf-8\nimport torch\nfrom scipy.optimize import linear_sum_assignment\nfrom .util import generalized_box_iou, masks_to_boxes\nimport copy\n\n\nclass SimpleMaskTracker(object):\n    def __init__(self, score_thresh, max_age=32):\n        self.score_thresh = score_thresh\n        self.max_age = max_age\n        self.id_count = 0\n        self.tracks_dict = dict()\n        self.tracks = list()\n        self.unmatched_tracks = list()\n        self.reset_all()\n\n    def reset_all(self):\n        self.id_count = 0\n        self.tracks_dict = dict()\n        self.tracks = list()\n        self.unmatched_tracks = list()\n\n    def init_track(self, results):\n\n        scores = results[\"scores\"] # (n,)\n        masks = results[\"masks\"]  # (n,h,w)\n\n        ret = list()\n        ret_dict = dict()\n        for idx in range(scores.shape[0]):\n            if scores[idx] >= self.score_thresh:\n                self.id_count += 1\n                obj = dict()\n                obj[\"score\"] = float(scores[idx])\n                obj[\"mask\"] = masks[idx]\n                obj[\"tracking_id\"] = self.id_count\n                obj['active'] = 1\n                obj['age'] = 1\n                ret.append(obj)\n                ret_dict[idx] = obj\n\n        self.tracks = ret\n        self.tracks_dict = ret_dict\n        return copy.deepcopy(ret)\n\n    def step(self, output_results, track_results):\n        \"\"\"\n        Args:\n            output_results: Current Frame Output including the tracked results\n        Returns:\n        \"\"\"\n        scores = output_results[\"scores\"]  # (n,h,w)\n        bboxes = output_results[\"masks\"]  # (n,h,w)\n        # track_bboxes = track_results[\"masks\"]  # (m,h,w)\n\n        results = list()\n        results_dict = dict()\n\n        # tracks = list()\n        # for idx in 
range(scores.shape[0]):\n        #     if idx in self.tracks_dict and idx < len(track_bboxes):\n        #         self.tracks_dict[idx][\"mask\"] = track_bboxes[idx]\n        #\n        #     if scores[idx] >= self.score_thresh:\n        #         obj = dict()\n        #         obj[\"score\"] = float(scores[idx])\n        #         obj[\"mask\"] = bboxes[idx]\n        #         results.append(obj)\n        #         results_dict[idx] = obj\n\n        tracks = [v for v in self.tracks_dict.values()] + self.unmatched_tracks\n        N = len(results)\n        M = len(tracks)\n\n        ret = list()\n        unmatched_tracks = [t for t in range(M)]\n        unmatched_dets = [d for d in range(N)]\n        if N > 0 and M > 0:\n            det_box = masks_to_boxes(torch.stack([torch.tensor(obj['mask']) for obj in results], dim=0))  # N x h * w\n            track_box = masks_to_boxes(torch.stack([torch.tensor(obj['mask']) for obj in tracks], dim=0))  # M x h * w\n            cost_bbox = 1.0 - generalized_box_iou(det_box, track_box)  # N x M\n\n            matched_indices = linear_sum_assignment(cost_bbox)\n            unmatched_dets = [d for d in range(N) if not (d in matched_indices[0])]\n            unmatched_tracks = [d for d in range(M) if not (d in matched_indices[1])]\n\n            matches = [[], []]\n            for (m0, m1) in zip(matched_indices[0], matched_indices[1]):\n                if cost_bbox[m0, m1] > 1.2:\n                    unmatched_dets.append(m0)\n                    unmatched_tracks.append(m1)\n                else:\n                    matches[0].append(m0)\n                    matches[1].append(m1)\n\n            # handle the matched tracks\n            for (m0, m1) in zip(matches[0], matches[1]):\n                track = results[m0]\n                track['tracking_id'] = tracks[m1]['tracking_id']\n                track['age'] = 1\n                track['active'] = 1\n                ret.append(track)\n\n        for i in unmatched_dets:\n        
    track = results[i]\n            self.id_count += 1\n            track['tracking_id'] = self.id_count\n            track['age'] = 1\n            track['active'] = 1\n            ret.append(track)\n\n        # snapshot the active tracks before the inactive ones are appended to ret\n        current_track = copy.deepcopy(ret)\n\n        # handle the remaining tracks\n        ret_unmatched_tracks = []\n        for i in unmatched_tracks:\n            track = tracks[i]\n            if track['age'] < self.max_age:\n                track['age'] += 1\n                track['active'] = 0\n                ret.append(track)\n                ret_unmatched_tracks.append(track)\n\n        self.tracks = ret\n        self.tracks_dict = results_dict\n        self.unmatched_tracks = ret_unmatched_tracks\n        return current_track\n"
  },
  {
    "path": "knet/video/util.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved\n\"\"\"\nUtilities for bounding box manipulation and GIoU.\n\"\"\"\nimport torch\nfrom torchvision.ops.boxes import box_area\n\n\ndef box_cxcywh_to_xyxy(x):\n    x_c, y_c, w, h = x.unbind(-1)\n    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),\n         (x_c + 0.5 * w), (y_c + 0.5 * h)]\n    return torch.stack(b, dim=-1)\n\n\ndef box_xyxy_to_cxcywh(x):\n    x0, y0, x1, y1 = x.unbind(-1)\n    b = [(x0 + x1) / 2, (y0 + y1) / 2,\n         (x1 - x0), (y1 - y0)]\n    return torch.stack(b, dim=-1)\n\n\n# modified from torchvision to also return the union\ndef box_iou(boxes1, boxes2):\n    area1 = box_area(boxes1)\n    area2 = box_area(boxes2)\n\n    lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]\n    rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]\n\n    wh = (rb - lt).clamp(min=0)  # [N,M,2]\n    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]\n\n    union = area1[:, None] + area2 - inter\n\n    iou = inter / union\n    return iou, union\n\n\ndef generalized_box_iou(boxes1, boxes2):\n    \"\"\"\n    Generalized IoU from https://giou.stanford.edu/\n\n    The boxes should be in [x0, y0, x1, y1] format\n\n    Returns a [N, M] pairwise matrix, where N = len(boxes1)\n    and M = len(boxes2)\n    \"\"\"\n    # degenerate boxes gives inf / nan results\n    # so do an early check\n    assert (boxes1[:, 2:] >= boxes1[:, :2]).all()\n    assert (boxes2[:, 2:] >= boxes2[:, :2]).all()\n    iou, union = box_iou(boxes1, boxes2)\n\n    lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])\n    rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])\n\n    wh = (rb - lt).clamp(min=0)  # [N,M,2]\n    area = wh[:, :, 0] * wh[:, :, 1]\n\n    return iou - (area - union) / area\n\n\ndef masks_to_boxes(masks):\n    \"\"\"Compute the bounding boxes around the provided masks\n\n    The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.\n\n    Returns 
a [N, 4] tensor, with the boxes in xyxy format\n    \"\"\"\n    if masks.numel() == 0:\n        return torch.zeros((0, 4), device=masks.device)\n\n    h, w = masks.shape[-2:]\n\n    # build the coordinate grids on the same device as the masks\n    y = torch.arange(0, h, dtype=torch.float, device=masks.device)\n    x = torch.arange(0, w, dtype=torch.float, device=masks.device)\n    y, x = torch.meshgrid(y, x)\n\n    x_mask = (masks * x.unsqueeze(0))\n    x_max = x_mask.flatten(1).max(-1)[0]\n    x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]\n\n    y_mask = (masks * y.unsqueeze(0))\n    y_max = y_mask.flatten(1).max(-1)[0]\n    y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]\n\n    return torch.stack([x_min, y_min, x_max, y_max], 1)\n"
  },
  {
    "path": "knet_vis/__init__.py",
    "content": ""
  },
  {
    "path": "knet_vis/det/__init__.py",
    "content": ""
  },
  {
    "path": "knet_vis/det/kernel_head.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom mmcv.cnn import (ConvModule, bias_init_with_prob, normal_init)\nfrom mmdet.core import build_assigner, build_sampler, multi_apply, reduce_mean\nfrom mmdet.models.builder import HEADS, build_loss, build_neck\nfrom mmdet.models.losses import accuracy\nfrom mmdet.utils import get_root_logger\n\n\n@HEADS.register_module()\nclass ConvKernelHead(nn.Module):\n\n    def __init__(self,\n                 num_proposals=100,\n                 in_channels=256,\n                 out_channels=256,\n                 num_heads=8,\n                 num_cls_fcs=1,\n                 num_seg_convs=1,\n                 num_loc_convs=1,\n                 att_dropout=False,\n                 localization_fpn=None,\n                 conv_kernel_size=1,\n                 norm_cfg=dict(type='GN', num_groups=32),\n                 semantic_fpn=True,\n                 train_cfg=None,\n                 num_classes=80,\n                 xavier_init_kernel=False,\n                 kernel_init_std=0.01,\n                 use_binary=False,\n                 proposal_feats_with_obj=False,\n                 loss_mask=None,\n                 loss_seg=None,\n                 loss_cls=None,\n                 loss_dice=None,\n                 loss_rank=None,\n                 feat_downsample_stride=1,\n                 feat_refine_stride=1,\n                 feat_refine=True,\n                 with_embed=False,\n                 feat_embed_only=False,\n                 conv_normal_init=False,\n                 mask_out_stride=4,\n                 hard_target=False,\n                 num_thing_classes=80,\n                 num_stuff_classes=53,\n                 mask_assign_stride=4,\n                 ignore_label=255,\n                 thing_label_in_seg=0,\n                 cat_stuff_mask=False,\n                 **kwargs):\n        super(ConvKernelHead, self).__init__()\n        self.num_proposals = 
num_proposals\n        self.num_cls_fcs = num_cls_fcs\n        self.train_cfg = train_cfg\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.num_classes = num_classes\n        self.proposal_feats_with_obj = proposal_feats_with_obj\n        self.sampling = False\n        self.localization_fpn = build_neck(localization_fpn)\n        self.semantic_fpn = semantic_fpn\n        self.norm_cfg = norm_cfg\n        self.num_heads = num_heads\n        self.att_dropout = att_dropout\n        self.mask_out_stride = mask_out_stride\n        self.hard_target = hard_target\n        self.conv_kernel_size = conv_kernel_size\n        self.xavier_init_kernel = xavier_init_kernel\n        self.kernel_init_std = kernel_init_std\n        self.feat_downsample_stride = feat_downsample_stride\n        self.feat_refine_stride = feat_refine_stride\n        self.conv_normal_init = conv_normal_init\n        self.feat_refine = feat_refine\n        self.with_embed = with_embed\n        self.feat_embed_only = feat_embed_only\n        self.num_loc_convs = num_loc_convs\n        self.num_seg_convs = num_seg_convs\n        self.use_binary = use_binary\n        self.num_thing_classes = num_thing_classes\n        self.num_stuff_classes = num_stuff_classes\n        self.mask_assign_stride = mask_assign_stride\n        self.ignore_label = ignore_label\n        self.thing_label_in_seg = thing_label_in_seg\n        self.cat_stuff_mask = cat_stuff_mask\n\n        if loss_mask is not None:\n            self.loss_mask = build_loss(loss_mask)\n        else:\n            self.loss_mask = loss_mask\n\n        if loss_dice is not None:\n            self.loss_dice = build_loss(loss_dice)\n        else:\n            self.loss_dice = loss_dice\n\n        if loss_seg is not None:\n            self.loss_seg = build_loss(loss_seg)\n        else:\n            self.loss_seg = loss_seg\n        if loss_cls is not None:\n            self.loss_cls = build_loss(loss_cls)\n        
else:\n            self.loss_cls = loss_cls\n\n        if loss_rank is not None:\n            self.loss_rank = build_loss(loss_rank)\n        else:\n            self.loss_rank = loss_rank\n\n        if self.train_cfg:\n            self.assigner = build_assigner(self.train_cfg.assigner)\n            # use PseudoSampler when sampling is False\n            if self.sampling and hasattr(self.train_cfg, 'sampler'):\n                sampler_cfg = self.train_cfg.sampler\n            else:\n                sampler_cfg = dict(type='MaskPseudoSampler')\n            self.sampler = build_sampler(sampler_cfg, context=self)\n        self._init_layers()\n\n    def _init_layers(self):\n        \"\"\"Initialize a sparse set of proposal boxes and proposal features.\"\"\"\n        self.init_kernels = nn.Conv2d(\n            self.out_channels,\n            self.num_proposals,\n            self.conv_kernel_size,\n            padding=int(self.conv_kernel_size // 2),\n            bias=False)\n\n        if self.semantic_fpn:\n            if self.loss_seg.use_sigmoid:\n                self.conv_seg = nn.Conv2d(self.out_channels, self.num_classes,\n                                          1)\n            else:\n                self.conv_seg = nn.Conv2d(self.out_channels,\n                                          self.num_classes + 1, 1)\n\n        if self.feat_downsample_stride > 1 and self.feat_refine:\n            self.ins_downsample = ConvModule(\n                self.in_channels,\n                self.out_channels,\n                3,\n                stride=self.feat_refine_stride,\n                padding=1,\n                norm_cfg=self.norm_cfg)\n            self.seg_downsample = ConvModule(\n                self.in_channels,\n                self.out_channels,\n                3,\n                stride=self.feat_refine_stride,\n                padding=1,\n                norm_cfg=self.norm_cfg)\n\n        self.loc_convs = nn.ModuleList()\n        for i in 
range(self.num_loc_convs):\n            self.loc_convs.append(\n                ConvModule(\n                    self.in_channels,\n                    self.out_channels,\n                    1,\n                    norm_cfg=self.norm_cfg))\n\n        self.seg_convs = nn.ModuleList()\n        for i in range(self.num_seg_convs):\n            self.seg_convs.append(\n                ConvModule(\n                    self.in_channels,\n                    self.out_channels,\n                    1,\n                    norm_cfg=self.norm_cfg))\n\n    def init_weights(self):\n        self.localization_fpn.init_weights()\n\n        if self.feat_downsample_stride > 1 and self.conv_normal_init:\n            logger = get_root_logger()\n            logger.info('Initialize convs in KPN head by normal std 0.01')\n            for conv in [self.loc_convs, self.seg_convs]:\n                for m in conv.modules():\n                    if isinstance(m, nn.Conv2d):\n                        normal_init(m, std=0.01)\n\n        if self.semantic_fpn:\n            bias_seg = bias_init_with_prob(0.01)\n            if self.loss_seg.use_sigmoid:\n                normal_init(self.conv_seg, std=0.01, bias=bias_seg)\n            else:\n                normal_init(self.conv_seg, mean=0, std=0.01)\n        if self.xavier_init_kernel:\n            logger = get_root_logger()\n            logger.info('Initialize kernels by xavier uniform')\n            nn.init.xavier_uniform_(self.init_kernels.weight)\n        else:\n            logger = get_root_logger()\n            logger.info(\n                f'Initialize kernels by normal std: {self.kernel_init_std}')\n            normal_init(self.init_kernels, mean=0, std=self.kernel_init_std)\n\n    def _decode_init_proposals(self, img, img_metas):\n        num_imgs = len(img_metas)\n\n        localization_feats = self.localization_fpn(img)\n        if isinstance(localization_feats, list):\n            loc_feats = localization_feats[0]\n        else:\n       
     loc_feats = localization_feats\n        for conv in self.loc_convs:\n            loc_feats = conv(loc_feats)\n        if self.feat_downsample_stride > 1 and self.feat_refine:\n            loc_feats = self.ins_downsample(loc_feats)\n        mask_preds = self.init_kernels(loc_feats)\n\n        if self.semantic_fpn:\n            if isinstance(localization_feats, list):\n                semantic_feats = localization_feats[1]\n            else:\n                semantic_feats = localization_feats\n            for conv in self.seg_convs:\n                semantic_feats = conv(semantic_feats)\n            if self.feat_downsample_stride > 1 and self.feat_refine:\n                semantic_feats = self.seg_downsample(semantic_feats)\n        else:\n            semantic_feats = None\n\n        if semantic_feats is not None:\n            seg_preds = self.conv_seg(semantic_feats)\n        else:\n            seg_preds = None\n\n        proposal_feats = self.init_kernels.weight.clone()\n        proposal_feats = proposal_feats[None].expand(num_imgs,\n                                                     *proposal_feats.size())\n\n        if semantic_feats is not None:\n            x_feats = semantic_feats + loc_feats\n        else:\n            x_feats = loc_feats\n\n        if self.proposal_feats_with_obj:\n            sigmoid_masks = mask_preds.sigmoid()\n            nonzero_inds = sigmoid_masks > 0.5\n            if self.use_binary:\n                sigmoid_masks = nonzero_inds.float()\n            else:\n                sigmoid_masks = nonzero_inds.float() * sigmoid_masks\n            obj_feats = torch.einsum('bnhw,bchw->bnc', sigmoid_masks, x_feats)\n\n        cls_scores = None\n\n        if self.proposal_feats_with_obj:\n            proposal_feats = proposal_feats + obj_feats.view(\n                num_imgs, self.num_proposals, self.out_channels, 1, 1)\n\n        if self.cat_stuff_mask and not self.training:\n            mask_preds = torch.cat(\n                
[mask_preds, seg_preds[:, self.num_thing_classes:]], dim=1)\n            stuff_kernels = self.conv_seg.weight[self.\n                                                 num_thing_classes:].clone()\n            stuff_kernels = stuff_kernels[None].expand(num_imgs,\n                                                       *stuff_kernels.size())\n            proposal_feats = torch.cat([proposal_feats, stuff_kernels], dim=1)\n\n        return proposal_feats, x_feats, mask_preds, cls_scores, seg_preds\n\n    def forward_train(self,\n                      img,\n                      img_metas,\n                      gt_masks,\n                      gt_labels,\n                      gt_sem_seg=None,\n                      gt_sem_cls=None):\n        \"\"\"Forward function in training stage.\"\"\"\n        num_imgs = len(img_metas)\n        results = self._decode_init_proposals(img, img_metas)\n        (proposal_feats, x_feats, mask_preds, cls_scores, seg_preds) = results\n        if self.feat_downsample_stride > 1:\n            scaled_mask_preds = F.interpolate(\n                mask_preds,\n                scale_factor=self.feat_downsample_stride,\n                mode='bilinear',\n                align_corners=False)\n            if seg_preds is not None:\n                scaled_seg_preds = F.interpolate(\n                    seg_preds,\n                    scale_factor=self.feat_downsample_stride,\n                    mode='bilinear',\n                    align_corners=False)\n            else:\n                scaled_seg_preds = None\n        else:\n            scaled_mask_preds = mask_preds\n            scaled_seg_preds = seg_preds\n\n        if self.hard_target:\n            gt_masks = [x.bool().float() for x in gt_masks]\n        else:\n            gt_masks = gt_masks\n\n        sampling_results = []\n        if cls_scores is None:\n            detached_cls_scores = [None] * num_imgs\n        else:\n            detached_cls_scores = cls_scores.detach()\n\n        for i in range(num_imgs):\n            assign_result = 
self.assigner.assign(scaled_mask_preds[i].detach(),\n                                                 detached_cls_scores[i],\n                                                 gt_masks[i], gt_labels[i],\n                                                 img_metas[i])\n            sampling_result = self.sampler.sample(assign_result,\n                                                  scaled_mask_preds[i],\n                                                  gt_masks[i])\n            sampling_results.append(sampling_result)\n\n        mask_targets = self.get_targets(\n            sampling_results,\n            gt_masks,\n            self.train_cfg,\n            True,\n            gt_sem_seg=gt_sem_seg,\n            gt_sem_cls=gt_sem_cls)\n\n        losses = self.loss(scaled_mask_preds, cls_scores, scaled_seg_preds,\n                           proposal_feats, *mask_targets)\n\n        if self.cat_stuff_mask and self.training:\n            mask_preds = torch.cat(\n                [mask_preds, seg_preds[:, self.num_thing_classes:]], dim=1)\n            stuff_kernels = self.conv_seg.weight[self.\n                                                 num_thing_classes:].clone()\n            stuff_kernels = stuff_kernels[None].expand(num_imgs,\n                                                       *stuff_kernels.size())\n            proposal_feats = torch.cat([proposal_feats, stuff_kernels], dim=1)\n\n        return losses, proposal_feats, x_feats, mask_preds, cls_scores\n\n    def loss(self,\n             mask_pred,\n             cls_scores,\n             seg_preds,\n             proposal_feats,\n             labels,\n             label_weights,\n             mask_targets,\n             mask_weights,\n             seg_targets,\n             reduction_override=None,\n             **kwargs):\n        losses = dict()\n        bg_class_ind = self.num_classes\n        # note: in Sparse R-CNN, num_gt == num_pos\n        pos_inds = (labels >= 0) & (labels < bg_class_ind)\n        num_preds 
= mask_pred.shape[0] * mask_pred.shape[1]\n\n        if cls_scores is not None:\n            num_pos = pos_inds.sum().float()\n            avg_factor = reduce_mean(num_pos)\n            assert mask_pred.shape[0] == cls_scores.shape[0]\n            assert mask_pred.shape[1] == cls_scores.shape[1]\n            losses['loss_rpn_cls'] = self.loss_cls(\n                cls_scores.view(num_preds, -1),\n                labels,\n                label_weights,\n                avg_factor=avg_factor,\n                reduction_override=reduction_override)\n            losses['rpn_pos_acc'] = accuracy(\n                cls_scores.view(num_preds, -1)[pos_inds], labels[pos_inds])\n\n        bool_pos_inds = pos_inds.type(torch.bool)\n        # 0~self.num_classes-1 are FG, self.num_classes is BG\n        # do not perform bounding box regression for BG anymore.\n        H, W = mask_pred.shape[-2:]\n        if pos_inds.any():\n            pos_mask_pred = mask_pred.reshape(num_preds, H, W)[bool_pos_inds]\n            pos_mask_targets = mask_targets[bool_pos_inds]\n            losses['loss_rpn_mask'] = self.loss_mask(pos_mask_pred,\n                                                     pos_mask_targets)\n            losses['loss_rpn_dice'] = self.loss_dice(pos_mask_pred,\n                                                     pos_mask_targets)\n\n            if self.loss_rank is not None:\n                batch_size = mask_pred.size(0)\n                rank_target = mask_targets.new_full((batch_size, H, W),\n                                                    self.ignore_label,\n                                                    dtype=torch.long)\n                rank_inds = pos_inds.view(batch_size,\n                                          -1).nonzero(as_tuple=False)\n                batch_mask_targets = mask_targets.view(batch_size, -1, H,\n                                                       W).bool()\n                for i in range(batch_size):\n                    curr_inds = 
(rank_inds[:, 0] == i)\n                    curr_rank = rank_inds[:, 1][curr_inds]\n                    for j in curr_rank:\n                        rank_target[i][batch_mask_targets[i][j]] = j\n                losses['loss_rpn_rank'] = self.loss_rank(\n                    mask_pred, rank_target, ignore_index=self.ignore_label)\n\n        else:\n            losses['loss_rpn_mask'] = mask_pred.sum() * 0\n            losses['loss_rpn_dice'] = mask_pred.sum() * 0\n            if self.loss_rank is not None:\n                losses['loss_rpn_rank'] = mask_pred.sum() * 0\n\n        if seg_preds is not None:\n            if self.loss_seg.use_sigmoid:\n                cls_channel = seg_preds.shape[1]\n                flatten_seg = seg_preds.view(\n                    -1, cls_channel,\n                    H * W).permute(0, 2, 1).reshape(-1, cls_channel)\n                flatten_seg_target = seg_targets.view(-1)\n                num_dense_pos = (flatten_seg_target >= 0) & (\n                    flatten_seg_target < bg_class_ind)\n                num_dense_pos = num_dense_pos.sum().float().clamp(min=1.0)\n                losses['loss_rpn_seg'] = self.loss_seg(\n                    flatten_seg,\n                    flatten_seg_target,\n                    avg_factor=num_dense_pos)\n            else:\n                cls_channel = seg_preds.shape[1]\n                flatten_seg = seg_preds.view(-1, cls_channel, H * W).permute(\n                    0, 2, 1).reshape(-1, cls_channel)\n                flatten_seg_target = seg_targets.view(-1)\n                losses['loss_rpn_seg'] = self.loss_seg(flatten_seg,\n                                                       flatten_seg_target)\n\n        return losses\n\n    def _get_target_single(self, pos_inds, neg_inds, pos_mask, neg_mask,\n                           pos_gt_mask, pos_gt_labels, gt_sem_seg, gt_sem_cls,\n                           cfg):\n        num_pos = pos_mask.size(0)\n        num_neg = neg_mask.size(0)\n        
num_samples = num_pos + num_neg\n        H, W = pos_mask.shape[-2:]\n        # original implementation uses new_zeros since BG are set to be 0\n        # now use empty & fill because BG cat_id = num_classes,\n        # FG cat_id = [0, num_classes-1]\n        labels = pos_mask.new_full((num_samples, ),\n                                   self.num_classes,\n                                   dtype=torch.long)\n        label_weights = pos_mask.new_zeros(num_samples)\n        mask_targets = pos_mask.new_zeros(num_samples, H, W)\n        mask_weights = pos_mask.new_zeros(num_samples, H, W)\n        seg_targets = pos_mask.new_full((H, W),\n                                        self.num_classes,\n                                        dtype=torch.long)\n\n        if gt_sem_cls is not None and gt_sem_seg is not None:\n            gt_sem_seg = gt_sem_seg.bool()\n            for sem_mask, sem_cls in zip(gt_sem_seg, gt_sem_cls):\n                seg_targets[sem_mask] = sem_cls.long()\n\n        if num_pos > 0:\n            labels[pos_inds] = pos_gt_labels\n            pos_weight = 1.0 if cfg.pos_weight <= 0 else cfg.pos_weight\n            label_weights[pos_inds] = pos_weight\n            mask_targets[pos_inds, ...] = pos_gt_mask\n            mask_weights[pos_inds, ...] 
= 1\n            for i in range(num_pos):\n                seg_targets[pos_gt_mask[i].bool()] = pos_gt_labels[i]\n\n        if num_neg > 0:\n            label_weights[neg_inds] = 1.0\n\n        return labels, label_weights, mask_targets, mask_weights, seg_targets\n\n    def get_targets(self,\n                    sampling_results,\n                    gt_mask,\n                    rpn_train_cfg,\n                    concat=True,\n                    gt_sem_seg=None,\n                    gt_sem_cls=None):\n        num_imgs = len(sampling_results)\n        pos_inds_list = [res.pos_inds for res in sampling_results]\n        neg_inds_list = [res.neg_inds for res in sampling_results]\n        pos_mask_list = [res.pos_masks for res in sampling_results]\n        neg_mask_list = [res.neg_masks for res in sampling_results]\n        pos_gt_mask_list = [res.pos_gt_masks for res in sampling_results]\n        pos_gt_labels_list = [res.pos_gt_labels for res in sampling_results]\n        if gt_sem_seg is None:\n            gt_sem_seg = [None] * num_imgs\n            gt_sem_cls = [None] * num_imgs\n        results = multi_apply(\n            self._get_target_single,\n            pos_inds_list,\n            neg_inds_list,\n            pos_mask_list,\n            neg_mask_list,\n            pos_gt_mask_list,\n            pos_gt_labels_list,\n            gt_sem_seg,\n            gt_sem_cls,\n            cfg=rpn_train_cfg)\n        (labels, label_weights, mask_targets, mask_weights,\n         seg_targets) = results\n        if concat:\n            labels = torch.cat(labels, 0)\n            label_weights = torch.cat(label_weights, 0)\n            mask_targets = torch.cat(mask_targets, 0)\n            mask_weights = torch.cat(mask_weights, 0)\n            seg_targets = torch.stack(seg_targets, 0)\n        return labels, label_weights, mask_targets, mask_weights, seg_targets\n\n    def simple_test_rpn(self, img, img_metas):\n        \"\"\"Forward function in testing stage.\"\"\"\n        
return self._decode_init_proposals(img, img_metas)\n\n    def forward_dummy(self, img, img_metas):\n        \"\"\"Dummy forward function.\n\n        Used in flops calculation.\n        \"\"\"\n        return self._decode_init_proposals(img, img_metas)\n"
  },
  {
    "path": "knet_vis/det/kernel_iter_head.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom mmdet.core import build_assigner, build_sampler\nfrom mmdet.datasets.coco_panoptic import INSTANCE_OFFSET\nfrom mmdet.models.builder import HEADS, build_head\nfrom mmdet.models.roi_heads import BaseRoIHead\nfrom .mask_pseudo_sampler import MaskPseudoSampler\n\n\n@HEADS.register_module()\nclass KernelIterHead(BaseRoIHead):\n\n    def __init__(self,\n                 num_stages=6,\n                 recursive=False,\n                 assign_stages=5,\n                 stage_loss_weights=(1, 1, 1, 1, 1, 1),\n                 proposal_feature_channel=256,\n                 merge_cls_scores=False,\n                 do_panoptic=False,\n                 post_assign=False,\n                 hard_target=False,\n                 num_proposals=100,\n                 num_thing_classes=80,\n                 num_stuff_classes=53,\n                 mask_assign_stride=4,\n                 thing_label_in_seg=0,\n                 mask_head=dict(\n                     type='KernelUpdateHead',\n                     num_classes=80,\n                     num_fcs=2,\n                     num_heads=8,\n                     num_cls_fcs=1,\n                     num_reg_fcs=3,\n                     feedforward_channels=2048,\n                     hidden_channels=256,\n                     dropout=0.0,\n                     roi_feat_size=7,\n                     ffn_act_cfg=dict(type='ReLU', inplace=True)),\n                 mask_out_stride=4,\n                 train_cfg=None,\n                 test_cfg=None,\n                 **kwargs):\n        assert mask_head is not None\n        assert len(stage_loss_weights) == num_stages\n        self.num_stages = num_stages\n        self.stage_loss_weights = stage_loss_weights\n        self.proposal_feature_channel = proposal_feature_channel\n        self.merge_cls_scores = merge_cls_scores\n        self.recursive = recursive\n        self.post_assign = 
post_assign\n        self.mask_out_stride = mask_out_stride\n        self.hard_target = hard_target\n        self.assign_stages = assign_stages\n        self.do_panoptic = do_panoptic\n        self.num_thing_classes = num_thing_classes\n        self.num_stuff_classes = num_stuff_classes\n        self.num_classes = num_thing_classes + num_stuff_classes\n        self.mask_assign_stride = mask_assign_stride\n        self.thing_label_in_seg = thing_label_in_seg\n        self.num_proposals = num_proposals\n        super(KernelIterHead, self).__init__(\n            mask_head=mask_head,\n            train_cfg=train_cfg,\n            test_cfg=test_cfg,\n            **kwargs)\n        # train_cfg is None when running test.py\n        if train_cfg is not None:\n            for stage in range(num_stages):\n                assert isinstance(\n                    self.mask_sampler[stage], MaskPseudoSampler), \\\n                    'Sparse Mask only supports `MaskPseudoSampler`'\n\n    def init_bbox_head(self, mask_roi_extractor, mask_head):\n        \"\"\"Initialize box head and box roi extractor.\n\n        Args:\n            mask_roi_extractor (dict): Config of box roi extractor.\n            mask_head (dict): Config of box in box head.\n        \"\"\"\n        pass\n\n    def init_assigner_sampler(self):\n        \"\"\"Initialize assigner and sampler for each stage.\"\"\"\n        self.mask_assigner = []\n        self.mask_sampler = []\n        if self.train_cfg is not None:\n            for idx, rcnn_train_cfg in enumerate(self.train_cfg):\n                self.mask_assigner.append(\n                    build_assigner(rcnn_train_cfg.assigner))\n                self.current_stage = idx\n                self.mask_sampler.append(\n                    build_sampler(rcnn_train_cfg.sampler, context=self))\n\n    def init_weights(self):\n        for i in range(self.num_stages):\n            self.mask_head[i].init_weights()\n\n    def init_mask_head(self, mask_roi_extractor, 
mask_head):\n        \"\"\"Initialize mask head and mask roi extractor.\n\n        Args:\n            mask_roi_extractor (dict): Config of mask roi extractor.\n            mask_head (dict): Config of mask in mask head.\n        \"\"\"\n        self.mask_head = nn.ModuleList()\n        if not isinstance(mask_head, list):\n            mask_head = [mask_head for _ in range(self.num_stages)]\n        assert len(mask_head) == self.num_stages\n        for head in mask_head:\n            self.mask_head.append(build_head(head))\n        if self.recursive:\n            for i in range(self.num_stages):\n                self.mask_head[i] = self.mask_head[0]\n\n    def _mask_forward(self, stage, x, object_feats, mask_preds, img_metas):\n        mask_head = self.mask_head[stage]\n        cls_score, mask_preds, object_feats = mask_head(\n            x, object_feats, mask_preds, img_metas=img_metas)\n        if mask_head.mask_upsample_stride > 1 and (stage == self.num_stages - 1\n                                                   or self.training):\n            scaled_mask_preds = F.interpolate(\n                mask_preds,\n                scale_factor=mask_head.mask_upsample_stride,\n                align_corners=False,\n                mode='bilinear')\n        else:\n            scaled_mask_preds = mask_preds\n        mask_results = dict(\n            cls_score=cls_score,\n            mask_preds=mask_preds,\n            scaled_mask_preds=scaled_mask_preds,\n            object_feats=object_feats)\n\n        return mask_results\n\n    def forward_train(self,\n                      x,\n                      proposal_feats,\n                      mask_preds,\n                      cls_score,\n                      img_metas,\n                      gt_masks,\n                      gt_labels,\n                      gt_bboxes_ignore=None,\n                      imgs_whwh=None,\n                      gt_bboxes=None,\n                      gt_sem_seg=None,\n                      
gt_sem_cls=None):\n\n        num_imgs = len(img_metas)\n        if self.mask_head[0].mask_upsample_stride > 1:\n            prev_mask_preds = F.interpolate(\n                mask_preds.detach(),\n                scale_factor=self.mask_head[0].mask_upsample_stride,\n                mode='bilinear',\n                align_corners=False)\n        else:\n            prev_mask_preds = mask_preds.detach()\n\n        if cls_score is not None:\n            prev_cls_score = cls_score.detach()\n        else:\n            prev_cls_score = [None] * num_imgs\n\n        if self.hard_target:\n            gt_masks = [x.bool().float() for x in gt_masks]\n        else:\n            gt_masks = gt_masks\n\n        object_feats = proposal_feats\n        all_stage_loss = {}\n        all_stage_mask_results = []\n        assign_results = []\n        for stage in range(self.num_stages):\n            mask_results = self._mask_forward(stage, x, object_feats,\n                                              mask_preds, img_metas)\n            all_stage_mask_results.append(mask_results)\n            mask_preds = mask_results['mask_preds']\n            scaled_mask_preds = mask_results['scaled_mask_preds']\n            cls_score = mask_results['cls_score']\n            object_feats = mask_results['object_feats']\n\n            if self.post_assign:\n                prev_mask_preds = scaled_mask_preds.detach()\n                prev_cls_score = cls_score.detach()\n\n            sampling_results = []\n            if stage < self.assign_stages:\n                assign_results = []\n            for i in range(num_imgs):\n                if stage < self.assign_stages:\n                    mask_for_assign = prev_mask_preds[i][:self.num_proposals]\n                    if prev_cls_score[i] is not None:\n                        cls_for_assign = prev_cls_score[\n                            i][:self.num_proposals, :self.num_thing_classes]\n                    else:\n                        cls_for_assign = 
None\n                    assign_result = self.mask_assigner[stage].assign(\n                        mask_for_assign, cls_for_assign, gt_masks[i],\n                        gt_labels[i], img_metas[i])\n                    assign_results.append(assign_result)\n                sampling_result = self.mask_sampler[stage].sample(\n                    assign_results[i], scaled_mask_preds[i], gt_masks[i])\n                sampling_results.append(sampling_result)\n            mask_targets = self.mask_head[stage].get_targets(\n                sampling_results,\n                self.train_cfg[stage],\n                True,\n                gt_sem_seg=gt_sem_seg,\n                gt_sem_cls=gt_sem_cls)\n\n            single_stage_loss = self.mask_head[stage].loss(\n                object_feats,\n                cls_score,\n                scaled_mask_preds,\n                *mask_targets,\n                imgs_whwh=imgs_whwh)\n            for key, value in single_stage_loss.items():\n                all_stage_loss[f's{stage}_{key}'] = value * \\\n                                    self.stage_loss_weights[stage]\n\n            if not self.post_assign:\n                prev_mask_preds = scaled_mask_preds.detach()\n                prev_cls_score = cls_score.detach()\n\n        return all_stage_loss\n\n    def simple_test(self,\n                    x,\n                    proposal_feats,\n                    mask_preds,\n                    cls_score,\n                    img_metas,\n                    imgs_whwh=None,\n                    rescale=False):\n\n        # Decode initial proposals\n        num_imgs = len(img_metas)\n        # num_proposals = proposal_feats.size(1)\n\n        object_feats = proposal_feats\n        for stage in range(self.num_stages):\n            mask_results = self._mask_forward(stage, x, object_feats,\n                                              mask_preds, img_metas)\n            object_feats = mask_results['object_feats']\n            cls_score = 
mask_results['cls_score']\n            mask_preds = mask_results['mask_preds']\n            scaled_mask_preds = mask_results['scaled_mask_preds']\n\n        num_classes = self.mask_head[-1].num_classes\n        results = []\n\n        if self.mask_head[-1].loss_cls.use_sigmoid:\n            cls_score = cls_score.sigmoid()\n        else:\n            cls_score = cls_score.softmax(-1)[..., :-1]\n\n        if self.do_panoptic:\n            for img_id in range(num_imgs):\n                single_result = self.get_panoptic(cls_score[img_id],\n                                                  scaled_mask_preds[img_id],\n                                                  self.test_cfg,\n                                                  img_metas[img_id])\n                results.append(single_result)\n        else:\n            for img_id in range(num_imgs):\n                cls_score_per_img = cls_score[img_id]\n                # Tricky: one mask proposal can yield several results with different labels, so top-k runs over all (proposal, class) pairs\n                scores_per_img, topk_indices = cls_score_per_img.flatten(0, 1).topk(\n                        self.test_cfg.max_per_img, sorted=True)\n                mask_indices = topk_indices // num_classes\n                # Use the following when torch >= 1.9.0\n                # mask_indices = torch.div(topk_indices, num_classes, rounding_mode='trunc')\n                labels_per_img = topk_indices % num_classes\n                masks_per_img = scaled_mask_preds[img_id][mask_indices]\n                single_result = self.mask_head[-1].get_seg_masks(\n                    masks_per_img, labels_per_img, scores_per_img,\n                    self.test_cfg, img_metas[img_id])\n                results.append(single_result)\n        return results\n\n    def aug_test(self, features, proposal_list, img_metas, rescale=False):\n        raise NotImplementedError('SparseMask does not support `aug_test`')\n\n    def forward_dummy(self, x, proposal_boxes, 
proposal_feats, img_metas):\n        \"\"\"Dummy forward function used for FLOPs computation.\"\"\"\n        all_stage_mask_results = []\n        num_imgs = len(img_metas)\n        num_proposals = proposal_feats.size(1)\n        C, H, W = x.shape[-3:]\n        mask_preds = proposal_feats.bmm(x.view(num_imgs, C, -1)).view(\n            num_imgs, num_proposals, H, W)\n        object_feats = proposal_feats\n        for stage in range(self.num_stages):\n            mask_results = self._mask_forward(stage, x, object_feats,\n                                              mask_preds, img_metas)\n            all_stage_mask_results.append(mask_results)\n        return all_stage_mask_results\n\n    def get_panoptic(self, cls_scores, mask_preds, test_cfg, img_meta):\n        # resize mask predictions back\n        scores = cls_scores[:self.num_proposals][:, :self.num_thing_classes]\n        thing_scores, thing_labels = scores.max(dim=1)\n        stuff_scores = cls_scores[\n            self.num_proposals:][:, self.num_thing_classes:].diag()\n        stuff_labels = torch.arange(\n            0, self.num_stuff_classes) + self.num_thing_classes\n        stuff_labels = stuff_labels.to(thing_labels.device)\n\n        total_masks = self.mask_head[-1].rescale_masks(mask_preds, img_meta)\n        total_scores = torch.cat([thing_scores, stuff_scores], dim=0)\n        total_labels = torch.cat([thing_labels, stuff_labels], dim=0)\n\n        panoptic_result = self.merge_stuff_thing(total_masks, total_labels,\n                                                 total_scores,\n                                                 test_cfg.merge_stuff_thing)\n        return dict(pan_results=panoptic_result)\n\n    def merge_stuff_thing(self,\n                          total_masks,\n                          total_labels,\n                          total_scores,\n                          merge_cfg=None):\n\n        H, W = total_masks.shape[-2:]\n        panoptic_seg = total_masks.new_full((H, W),\n  
                                          self.num_classes,\n                                            dtype=torch.long)\n\n        cur_prob_masks = total_scores.view(-1, 1, 1) * total_masks\n        cur_mask_ids = cur_prob_masks.argmax(0)\n\n        # sort instance outputs by scores\n        sorted_inds = torch.argsort(-total_scores)\n        current_segment_id = 0\n\n        for k in sorted_inds:\n            pred_class = total_labels[k].item()\n            isthing = pred_class < self.num_thing_classes\n            if isthing and total_scores[k] < merge_cfg.instance_score_thr:\n                continue\n\n            mask = cur_mask_ids == k\n            mask_area = mask.sum().item()\n            original_area = (total_masks[k] >= 0.5).sum().item()\n\n            if mask_area > 0 and original_area > 0:\n                if mask_area / original_area < merge_cfg.overlap_thr:\n                    continue\n\n                panoptic_seg[mask] = total_labels[k] \\\n                    + current_segment_id * INSTANCE_OFFSET\n                current_segment_id += 1\n\n        return panoptic_seg.cpu().numpy()\n"
  },
  {
    "path": "knet_vis/det/kernel_update_head.py",
    "content": "import numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom mmcv.cnn import (ConvModule, bias_init_with_prob, build_activation_layer,\n                      build_norm_layer)\nfrom mmcv.cnn.bricks.transformer import (FFN, MultiheadAttention,\n                                         build_transformer_layer)\nfrom mmcv.runner import force_fp32\n\nfrom mmdet.core import multi_apply\nfrom mmdet.models.builder import HEADS, build_loss\nfrom mmdet.models.dense_heads.atss_head import reduce_mean\nfrom mmdet.models.losses import accuracy\nfrom mmdet.utils import get_root_logger\n\nfrom mmtrack.transform import outs2results\n\n@HEADS.register_module()\nclass KernelUpdateHead(nn.Module):\n\n    def __init__(self,\n                 num_classes=80,\n                 num_ffn_fcs=2,\n                 num_heads=8,\n                 num_cls_fcs=1,\n                 num_mask_fcs=3,\n                 feedforward_channels=2048,\n                 in_channels=256,\n                 out_channels=256,\n                 dropout=0.0,\n                 mask_thr=0.5,\n                 act_cfg=dict(type='ReLU', inplace=True),\n                 ffn_act_cfg=dict(type='ReLU', inplace=True),\n                 conv_kernel_size=3,\n                 feat_transform_cfg=None,\n                 hard_mask_thr=0.5,\n                 kernel_init=False,\n                 with_ffn=True,\n                 mask_out_stride=4,\n                 relative_coors=False,\n                 relative_coors_off=False,\n                 feat_gather_stride=1,\n                 mask_transform_stride=1,\n                 mask_upsample_stride=1,\n                 num_thing_classes=80,\n                 num_stuff_classes=53,\n                 mask_assign_stride=4,\n                 ignore_label=255,\n                 thing_label_in_seg=0,\n                 kernel_updator_cfg=dict(\n                     type='DynamicConv',\n                     in_channels=256,\n              
       feat_channels=64,\n                     out_channels=256,\n                     input_feat_shape=1,\n                     act_cfg=dict(type='ReLU', inplace=True),\n                     norm_cfg=dict(type='LN')),\n                 loss_rank=None,\n                 loss_mask=dict(\n                     type='CrossEntropyLoss', use_mask=True, loss_weight=1.0),\n                 loss_dice=dict(type='DiceLoss', loss_weight=3.0),\n                 loss_cls=dict(\n                     type='FocalLoss',\n                     use_sigmoid=True,\n                     gamma=2.0,\n                     alpha=0.25,\n                     loss_weight=2.0)):\n        super(KernelUpdateHead, self).__init__()\n        self.num_classes = num_classes\n        self.loss_cls = build_loss(loss_cls)\n        self.loss_mask = build_loss(loss_mask)\n        self.loss_dice = build_loss(loss_dice)\n        if loss_rank is not None:\n            self.loss_rank = build_loss(loss_rank)\n        else:\n            self.loss_rank = loss_rank\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.mask_thr = mask_thr\n        self.fp16_enabled = False\n        self.dropout = dropout\n\n        self.num_heads = num_heads\n        self.hard_mask_thr = hard_mask_thr\n        self.kernel_init = kernel_init\n        self.with_ffn = with_ffn\n        self.mask_out_stride = mask_out_stride\n        self.relative_coors = relative_coors\n        self.relative_coors_off = relative_coors_off\n        self.conv_kernel_size = conv_kernel_size\n        self.feat_gather_stride = feat_gather_stride\n        self.mask_transform_stride = mask_transform_stride\n        self.mask_upsample_stride = mask_upsample_stride\n\n        self.num_thing_classes = num_thing_classes\n        self.num_stuff_classes = num_stuff_classes\n        self.mask_assign_stride = mask_assign_stride\n        self.ignore_label = ignore_label\n        self.thing_label_in_seg = 
thing_label_in_seg\n\n        self.attention = MultiheadAttention(in_channels * conv_kernel_size**2,\n                                            num_heads, dropout)\n        self.attention_norm = build_norm_layer(\n            dict(type='LN'), in_channels * conv_kernel_size**2)[1]\n\n        self.kernel_update_conv = build_transformer_layer(kernel_updator_cfg)\n\n        if feat_transform_cfg is not None:\n            kernel_size = feat_transform_cfg.pop('kernel_size', 1)\n            self.feat_transform = ConvModule(\n                in_channels,\n                in_channels,\n                kernel_size,\n                stride=feat_gather_stride,\n                padding=int(feat_gather_stride // 2),\n                **feat_transform_cfg)\n        else:\n            self.feat_transform = None\n\n        if self.with_ffn:\n            self.ffn = FFN(\n                in_channels,\n                feedforward_channels,\n                num_ffn_fcs,\n                act_cfg=ffn_act_cfg,\n                ffn_drop=dropout)\n            self.ffn_norm = build_norm_layer(dict(type='LN'), in_channels)[1]\n\n        self.cls_fcs = nn.ModuleList()\n        for _ in range(num_cls_fcs):\n            self.cls_fcs.append(\n                nn.Linear(in_channels, in_channels, bias=False))\n            self.cls_fcs.append(\n                build_norm_layer(dict(type='LN'), in_channels)[1])\n            self.cls_fcs.append(build_activation_layer(act_cfg))\n\n        if self.loss_cls.use_sigmoid:\n            self.fc_cls = nn.Linear(in_channels, self.num_classes)\n        else:\n            self.fc_cls = nn.Linear(in_channels, self.num_classes + 1)\n\n        self.mask_fcs = nn.ModuleList()\n        for _ in range(num_mask_fcs):\n            self.mask_fcs.append(\n                nn.Linear(in_channels, in_channels, bias=False))\n            self.mask_fcs.append(\n                build_norm_layer(dict(type='LN'), in_channels)[1])\n            
self.mask_fcs.append(build_activation_layer(act_cfg))\n\n        self.fc_mask = nn.Linear(in_channels, out_channels)\n\n    def init_weights(self):\n        \"\"\"Use Xavier initialization for all weight parameters and set the\n        classification head bias to a specific value when using focal loss.\"\"\"\n        for p in self.parameters():\n            if p.dim() > 1:\n                nn.init.xavier_uniform_(p)\n            else:\n                # adopt the default initialization for\n                # the weight and bias of the layer norm\n                pass\n        if self.loss_cls.use_sigmoid:\n            bias_init = bias_init_with_prob(0.01)\n            nn.init.constant_(self.fc_cls.bias, bias_init)\n        if self.kernel_init:\n            logger = get_root_logger()\n            logger.info(\n                'mask kernel in mask head is normal initialized by std 0.01')\n            nn.init.normal_(self.fc_mask.weight, mean=0, std=0.01)\n\n    def forward(self,\n                x,\n                proposal_feat,\n                mask_preds,\n                prev_cls_score=None,\n                mask_shape=None,\n                img_metas=None):\n\n        N, num_proposals = proposal_feat.shape[:2]\n        if self.feat_transform is not None:\n            x = self.feat_transform(x)\n        C, H, W = x.shape[-3:]\n\n        mask_h, mask_w = mask_preds.shape[-2:]\n        if mask_h != H or mask_w != W:\n            gather_mask = F.interpolate(\n                mask_preds, (H, W), align_corners=False, mode='bilinear')\n        else:\n            gather_mask = mask_preds\n\n        sigmoid_masks = gather_mask.sigmoid()\n        nonzero_inds = sigmoid_masks > self.hard_mask_thr\n        sigmoid_masks = nonzero_inds.float()\n\n        # einsum is faster than bmm by 30%\n        x_feat = torch.einsum('bnhw,bchw->bnc', sigmoid_masks, x)\n\n        # obj_feat in shape [B, N, C, K, K] -> [B, N, C, K*K] -> [B, N, K*K, C]\n        proposal_feat = 
proposal_feat.reshape(N, num_proposals,\n                                              self.in_channels,\n                                              -1).permute(0, 1, 3, 2)\n        obj_feat = self.kernel_update_conv(x_feat, proposal_feat)\n\n        # [B, N, K*K, C] -> [B, N, K*K*C] -> [N, B, K*K*C]\n        obj_feat = obj_feat.reshape(N, num_proposals, -1).permute(1, 0, 2)\n        obj_feat = self.attention_norm(self.attention(obj_feat))\n        # [N, B, K*K*C] -> [B, N, K*K*C]\n        obj_feat = obj_feat.permute(1, 0, 2)\n\n        # obj_feat in shape [B, N, K*K*C] -> [B, N, K*K, C]\n        obj_feat = obj_feat.reshape(N, num_proposals, -1, self.in_channels)\n\n        # FFN\n        if self.with_ffn:\n            obj_feat = self.ffn_norm(self.ffn(obj_feat))\n\n        cls_feat = obj_feat.sum(-2)\n        mask_feat = obj_feat\n\n        for cls_layer in self.cls_fcs:\n            cls_feat = cls_layer(cls_feat)\n        for reg_layer in self.mask_fcs:\n            mask_feat = reg_layer(mask_feat)\n\n        cls_score = self.fc_cls(cls_feat).view(N, num_proposals, -1)\n        # [B, N, K*K, C] -> [B, N, C, K*K]\n        mask_feat = self.fc_mask(mask_feat).permute(0, 1, 3, 2)\n\n        if (self.mask_transform_stride == 2 and self.feat_gather_stride == 1):\n            mask_x = F.interpolate(\n                x, scale_factor=0.5, mode='bilinear', align_corners=False)\n            H, W = mask_x.shape[-2:]\n            raise NotImplementedError\n        else:\n            mask_x = x\n        # group conv is 5x faster than unfold and uses about 1/5 memory\n        # Group conv vs. unfold vs. concat batch, 2.9ms :13.5ms :3.8ms\n        # Group conv vs. unfold vs. 
concat batch, 278 : 1420 : 369\n        # fold_x = F.unfold(\n        #     mask_x,\n        #     self.conv_kernel_size,\n        #     padding=int(self.conv_kernel_size // 2))\n        # mask_feat = mask_feat.reshape(N, num_proposals, -1)\n        # new_mask_preds = torch.einsum('bnc,bcl->bnl', mask_feat, fold_x)\n        # [B, N, C, K*K] -> [B, N, C, K, K]\n        mask_feat = mask_feat.reshape(N, num_proposals, C,\n                                      self.conv_kernel_size,\n                                      self.conv_kernel_size)\n        # for each image: [1, C, H, W] conv [N, C, K, K] kernels -> [1, N, H, W]\n        new_mask_preds = []\n        for i in range(N):\n            new_mask_preds.append(\n                F.conv2d(\n                    mask_x[i:i + 1],\n                    mask_feat[i],\n                    padding=int(self.conv_kernel_size // 2)))\n\n        new_mask_preds = torch.cat(new_mask_preds, dim=0)\n        new_mask_preds = new_mask_preds.reshape(N, num_proposals, H, W)\n        if self.mask_transform_stride == 2:\n            new_mask_preds = F.interpolate(\n                new_mask_preds,\n                scale_factor=2,\n                mode='bilinear',\n                align_corners=False)\n\n        if mask_shape is not None and mask_shape[0] != H:\n            new_mask_preds = F.interpolate(\n                new_mask_preds,\n                mask_shape,\n                align_corners=False,\n                mode='bilinear')\n\n        return cls_score, new_mask_preds, obj_feat.permute(0, 1, 3, 2).reshape(\n            N, num_proposals, self.in_channels, self.conv_kernel_size,\n            self.conv_kernel_size)\n\n    @force_fp32(apply_to=('cls_score', 'mask_pred'))\n    def loss(self,\n             object_feats,\n             cls_score,\n             mask_pred,\n             labels,\n             label_weights,\n             mask_targets,\n             mask_weights,\n             imgs_whwh=None,\n             reduction_override=None,\n             **kwargs):\n\n      
  losses = dict()\n        bg_class_ind = self.num_classes\n        # note: in Sparse R-CNN, num_gt == num_pos\n        pos_inds = (labels >= 0) & (labels < bg_class_ind)\n        num_pos = pos_inds.sum().float()\n        avg_factor = reduce_mean(num_pos).clamp_(min=1.0)\n\n        num_preds = mask_pred.shape[0] * mask_pred.shape[1]\n        assert mask_pred.shape[0] == cls_score.shape[0]\n        assert mask_pred.shape[1] == cls_score.shape[1]\n\n        if cls_score is not None:\n            if cls_score.numel() > 0:\n                losses['loss_cls'] = self.loss_cls(\n                    cls_score.view(num_preds, -1),\n                    labels,\n                    label_weights,\n                    avg_factor=avg_factor,\n                    reduction_override=reduction_override)\n                losses['pos_acc'] = accuracy(\n                    cls_score.view(num_preds, -1)[pos_inds], labels[pos_inds])\n        if mask_pred is not None:\n            bool_pos_inds = pos_inds.type(torch.bool)\n            # 0~self.num_classes-1 are FG, self.num_classes is BG\n            # only positive (FG) predictions contribute to the mask losses.\n            H, W = mask_pred.shape[-2:]\n            if pos_inds.any():\n                pos_mask_pred = mask_pred.reshape(num_preds, H,\n                                                  W)[bool_pos_inds]\n                pos_mask_targets = mask_targets[bool_pos_inds]\n                losses['loss_mask'] = self.loss_mask(pos_mask_pred,\n                                                     pos_mask_targets)\n                losses['loss_dice'] = self.loss_dice(pos_mask_pred,\n                                                     pos_mask_targets)\n\n                if self.loss_rank is not None:\n                    batch_size = mask_pred.size(0)\n                    rank_target = mask_targets.new_full((batch_size, H, W),\n                                                        self.ignore_label,\n                                          
              dtype=torch.long)\n                    rank_inds = pos_inds.view(batch_size,\n                                              -1).nonzero(as_tuple=False)\n                    batch_mask_targets = mask_targets.view(\n                        batch_size, -1, H, W).bool()\n                    for i in range(batch_size):\n                        curr_inds = (rank_inds[:, 0] == i)\n                        curr_rank = rank_inds[:, 1][curr_inds]\n                        for j in curr_rank:\n                            rank_target[i][batch_mask_targets[i][j]] = j\n                    losses['loss_rank'] = self.loss_rank(\n                        mask_pred, rank_target, ignore_index=self.ignore_label)\n            else:\n                losses['loss_mask'] = mask_pred.sum() * 0\n                losses['loss_dice'] = mask_pred.sum() * 0\n                if self.loss_rank is not None:\n                    losses['loss_rank'] = mask_pred.sum() * 0\n\n        return losses\n\n    def _get_target_single(self, pos_inds, neg_inds, pos_mask, neg_mask,\n                           pos_gt_mask, pos_gt_labels, gt_sem_seg, gt_sem_cls,\n                           cfg):\n\n        num_pos = pos_mask.size(0)\n        num_neg = neg_mask.size(0)\n        num_samples = num_pos + num_neg\n        H, W = pos_mask.shape[-2:]\n        # original implementation uses new_zeros since BG are set to be 0\n        # now use empty & fill because BG cat_id = num_classes,\n        # FG cat_id = [0, num_classes-1]\n        labels = pos_mask.new_full((num_samples, ),\n                                   self.num_classes,\n                                   dtype=torch.long)\n        label_weights = pos_mask.new_zeros((num_samples, self.num_classes))\n        mask_targets = pos_mask.new_zeros(num_samples, H, W)\n        mask_weights = pos_mask.new_zeros(num_samples, H, W)\n        if num_pos > 0:\n            labels[pos_inds] = pos_gt_labels\n            pos_weight = 1.0 if cfg.pos_weight <= 0 else 
cfg.pos_weight\n            label_weights[pos_inds] = pos_weight\n            pos_mask_targets = pos_gt_mask\n            mask_targets[pos_inds, ...] = pos_mask_targets\n            mask_weights[pos_inds, ...] = 1\n\n        if num_neg > 0:\n            label_weights[neg_inds] = 1.0\n\n        if gt_sem_cls is not None and gt_sem_seg is not None:\n            sem_labels = pos_mask.new_full((self.num_stuff_classes, ),\n                                           self.num_classes,\n                                           dtype=torch.long)\n            sem_targets = pos_mask.new_zeros(self.num_stuff_classes, H, W)\n            sem_weights = pos_mask.new_zeros(self.num_stuff_classes, H, W)\n            sem_stuff_weights = torch.eye(\n                self.num_stuff_classes, device=pos_mask.device)\n            sem_thing_weights = pos_mask.new_zeros(\n                (self.num_stuff_classes, self.num_thing_classes))\n            sem_label_weights = torch.cat(\n                [sem_thing_weights, sem_stuff_weights], dim=-1)\n            if len(gt_sem_cls) > 0:\n                sem_inds = gt_sem_cls - self.num_thing_classes\n                sem_inds = sem_inds.long()\n                sem_labels[sem_inds] = gt_sem_cls.long()\n                sem_targets[sem_inds] = gt_sem_seg\n                sem_weights[sem_inds] = 1\n\n            label_weights[:, self.num_thing_classes:] = 0\n            labels = torch.cat([labels, sem_labels])\n            label_weights = torch.cat([label_weights, sem_label_weights])\n            mask_targets = torch.cat([mask_targets, sem_targets])\n            mask_weights = torch.cat([mask_weights, sem_weights])\n\n        return labels, label_weights, mask_targets, mask_weights\n\n    def get_targets(self,\n                    sampling_results,\n                    rcnn_train_cfg,\n                    concat=True,\n                    gt_sem_seg=None,\n                    gt_sem_cls=None):\n        num_imgs = len(sampling_results)\n        
pos_inds_list = [res.pos_inds for res in sampling_results]\n        neg_inds_list = [res.neg_inds for res in sampling_results]\n        pos_mask_list = [res.pos_masks for res in sampling_results]\n        neg_mask_list = [res.neg_masks for res in sampling_results]\n        pos_gt_mask_list = [res.pos_gt_masks for res in sampling_results]\n        pos_gt_labels_list = [res.pos_gt_labels for res in sampling_results]\n        if gt_sem_seg is None:\n            gt_sem_seg = [None] * num_imgs\n            gt_sem_cls = [None] * num_imgs\n\n        labels, label_weights, mask_targets, mask_weights = multi_apply(\n            self._get_target_single,\n            pos_inds_list,\n            neg_inds_list,\n            pos_mask_list,\n            neg_mask_list,\n            pos_gt_mask_list,\n            pos_gt_labels_list,\n            gt_sem_seg,\n            gt_sem_cls,\n            cfg=rcnn_train_cfg)\n        if concat:\n            labels = torch.cat(labels, 0)\n            label_weights = torch.cat(label_weights, 0)\n            mask_targets = torch.cat(mask_targets, 0)\n            mask_weights = torch.cat(mask_weights, 0)\n        return labels, label_weights, mask_targets, mask_weights\n\n    def rescale_masks(self, masks_per_img, img_meta):\n        h, w, _ = img_meta['img_shape']\n        masks_per_img = F.interpolate(\n            masks_per_img.unsqueeze(0).sigmoid(),\n            size=img_meta['batch_input_shape'],\n            mode='bilinear',\n            align_corners=False)\n\n        masks_per_img = masks_per_img[:, :, :h, :w]\n        ori_shape = img_meta['ori_shape']\n        seg_masks = F.interpolate(\n            masks_per_img,\n            size=ori_shape[:2],\n            mode='bilinear',\n            align_corners=False).squeeze(0)\n        return seg_masks\n\n    def get_seg_masks(self, masks_per_img, labels_per_img, scores_per_img,\n                      test_cfg, img_meta):\n        # resize mask predictions back\n        seg_masks = 
self.rescale_masks(masks_per_img, img_meta)\n        seg_masks = seg_masks > test_cfg.mask_thr\n        bbox_result, segm_result = self.segm2result(seg_masks, labels_per_img,\n                                                    scores_per_img)\n        return bbox_result, segm_result\n\n    def segm2result(self, mask_preds, det_labels, cls_scores):\n        num_classes = self.num_classes\n        bbox_result = None\n        segm_result = [[] for _ in range(num_classes)]\n        mask_preds = mask_preds.cpu().numpy()\n        det_labels = det_labels.cpu().numpy()\n        cls_scores = cls_scores.cpu().numpy()\n        num_ins = mask_preds.shape[0]\n        # fake bboxes\n        bboxes = np.zeros((num_ins, 5), dtype=np.float32)\n        bboxes[:, -1] = cls_scores\n        bbox_result = [bboxes[det_labels == i, :] for i in range(num_classes)]\n        for idx in range(num_ins):\n            segm_result[det_labels[idx]].append(mask_preds[idx])\n        return bbox_result, segm_result\n\n    def get_seg_masks_tracking(self, masks_per_img, labels_per_img, scores_per_img, ids_per_img,\n                      test_cfg, img_meta):\n        num_ins = masks_per_img.shape[0]\n        # resize mask predictions back\n        seg_masks = self.rescale_masks(masks_per_img, img_meta)\n        seg_masks = seg_masks > test_cfg.mask_thr\n        # fake bboxes\n        bboxes = torch.zeros((num_ins, 5), dtype=torch.float32)\n        bboxes[:, -1] = scores_per_img\n        tracks = outs2results(\n            bboxes=bboxes,\n            labels=labels_per_img,\n            masks=seg_masks,\n            ids=ids_per_img,\n            num_classes=self.num_classes,\n        )\n        return tracks['bbox_results'], tracks['mask_results']\n"
  },
  {
    "path": "knet_vis/det/knet.py",
    "content": "import torch\nimport torch.nn.functional as F\n\nfrom mmdet.models.builder import DETECTORS\nfrom mmdet.models.detectors import TwoStageDetector\nfrom mmdet.utils import get_root_logger\nfrom .utils import sem2ins_masks\n\n\n@DETECTORS.register_module()\nclass KNet(TwoStageDetector):\n\n    def __init__(self,\n                 *args,\n                 num_thing_classes=80,\n                 num_stuff_classes=53,\n                 mask_assign_stride=4,\n                 thing_label_in_seg=0,\n                 **kwargs):\n        super(KNet, self).__init__(*args, **kwargs)\n        assert self.with_rpn, 'KNet does not support external proposals'\n        self.num_thing_classes = num_thing_classes\n        self.num_stuff_classes = num_stuff_classes\n        self.mask_assign_stride = mask_assign_stride\n        self.thing_label_in_seg = thing_label_in_seg\n        logger = get_root_logger()\n        logger.info(f'Model: \\n{self}')\n\n    def forward_train(self,\n                      img,\n                      img_metas,\n                      gt_bboxes=None,\n                      gt_labels=None,\n                      gt_bboxes_ignore=None,\n                      gt_masks=None,\n                      proposals=None,\n                      gt_semantic_seg=None,\n                      **kwargs):\n\n        super(TwoStageDetector, self).forward_train(img, img_metas)\n        assert proposals is None, 'KNet does not support' \\\n                                  ' external proposals'\n        assert gt_masks is not None\n\n        # gt_masks and gt_semantic_seg are not padded when forming batch\n        gt_masks_tensor = []\n        gt_sem_seg = []\n        gt_sem_cls = []\n        # batch_input_shape should be the same across images\n        pad_H, pad_W = img_metas[0]['batch_input_shape']\n        assign_H = pad_H // self.mask_assign_stride\n        assign_W = pad_W // self.mask_assign_stride\n\n        for i, gt_mask in enumerate(gt_masks):\n          
  mask_tensor = gt_mask.to_tensor(torch.float, gt_labels[0].device)\n            if gt_mask.width != pad_W or gt_mask.height != pad_H:\n                pad_wh = (0, pad_W - gt_mask.width, 0, pad_H - gt_mask.height)\n                mask_tensor = F.pad(mask_tensor, pad_wh, value=0)\n\n            if gt_semantic_seg is not None:\n                # gt_semantic seg is padded by 255 and\n                # zero indicating the first class\n                sem_labels, sem_seg = sem2ins_masks(\n                    gt_semantic_seg[i],\n                    num_thing_classes=self.num_thing_classes)\n                if sem_seg.shape[0] == 0:\n                    gt_sem_seg.append(\n                        mask_tensor.new_zeros(\n                            (mask_tensor.size(0), assign_H, assign_W)))\n                else:\n                    gt_sem_seg.append(\n                        F.interpolate(\n                            sem_seg[None], (assign_H, assign_W),\n                            mode='bilinear',\n                            align_corners=False)[0])\n                gt_sem_cls.append(sem_labels)\n\n            else:\n                gt_sem_seg = None\n                gt_sem_cls = None\n\n            if mask_tensor.shape[0] == 0:\n                gt_masks_tensor.append(\n                    mask_tensor.new_zeros(\n                        (mask_tensor.size(0), assign_H, assign_W)))\n            else:\n                gt_masks_tensor.append(\n                    F.interpolate(\n                        mask_tensor[None], (assign_H, assign_W),\n                        mode='bilinear',\n                        align_corners=False)[0])\n\n        gt_masks = gt_masks_tensor\n        x = self.extract_feat(img)\n        rpn_results = self.rpn_head.forward_train(x, img_metas, gt_masks,\n                                                  gt_labels, gt_sem_seg,\n                                                  gt_sem_cls)\n        (rpn_losses, proposal_feats, x_feats, 
mask_preds,\n         cls_scores) = rpn_results\n\n        losses = self.roi_head.forward_train(\n            x_feats,\n            proposal_feats,\n            mask_preds,\n            cls_scores,\n            img_metas,\n            gt_masks,\n            gt_labels,\n            gt_bboxes_ignore=gt_bboxes_ignore,\n            gt_bboxes=gt_bboxes,\n            gt_sem_seg=gt_sem_seg,\n            gt_sem_cls=gt_sem_cls,\n            imgs_whwh=None)\n\n        losses.update(rpn_losses)\n        return losses\n\n    def simple_test(self, img, img_metas, rescale=False):\n        x = self.extract_feat(img)\n        rpn_results = self.rpn_head.simple_test_rpn(x, img_metas)\n        (proposal_feats, x_feats, mask_preds, cls_scores,\n         seg_preds) = rpn_results\n        segm_results = self.roi_head.simple_test(\n            x_feats,\n            proposal_feats,\n            mask_preds,\n            cls_scores,\n            img_metas,\n            imgs_whwh=None,\n            rescale=rescale)\n        return segm_results\n\n    def forward_dummy(self, img):\n        \"\"\"Used for computing network flops.\n\n        See `mmdetection/tools/get_flops.py`\n        \"\"\"\n        # backbone\n        x = self.extract_feat(img)\n        # rpn\n        num_imgs = len(img)\n        dummy_img_metas = [\n            dict(img_shape=(800, 1333, 3)) for _ in range(num_imgs)\n        ]\n        rpn_results = self.rpn_head.simple_test_rpn(x, dummy_img_metas)\n        (proposal_feats, x_feats, mask_preds, cls_scores,\n         seg_preds) = rpn_results\n        # roi_head\n        roi_outs = self.roi_head.forward_dummy(x_feats, proposal_feats,\n                                               dummy_img_metas)\n        return roi_outs\n"
  },
  {
    "path": "knet_vis/det/mask_hungarian_assigner.py",
    "content": "import numpy as np\nimport torch\n\nfrom mmdet.core import AssignResult, BaseAssigner\nfrom mmdet.core.bbox.builder import BBOX_ASSIGNERS\nfrom mmdet.core.bbox.match_costs.builder import MATCH_COST, build_match_cost\n\ntry:\n    from scipy.optimize import linear_sum_assignment\nexcept ImportError:\n    linear_sum_assignment = None\n\n\n@MATCH_COST.register_module()\nclass DiceCost(object):\n    \"\"\"DiceCost.\n\n     Args:\n         weight (int | float, optional): loss_weight\n         pred_act (bool): Whether to activate the prediction\n            before calculating cost\n         act_mode (str): Activation applied when ``pred_act`` is True:\n            'sigmoid', otherwise softmax over dim 0. Default 'sigmoid'.\n         eps (float): Smoothing term added to the denominators.\n            Default 1e-3.\n\n     Examples:\n         >>> import torch\n         >>> self = DiceCost()\n         >>> mask_preds = torch.rand(2, 8, 8)\n         >>> gt_masks = torch.rand(3, 8, 8) > 0.5\n         >>> self(mask_preds, gt_masks).shape\n         torch.Size([2, 3])\n    \"\"\"\n\n    def __init__(self,\n                 weight=1.,\n                 pred_act=False,\n                 act_mode='sigmoid',\n                 eps=1e-3):\n        self.weight = weight\n        self.pred_act = pred_act\n        self.act_mode = act_mode\n        self.eps = eps\n\n    def dice_loss(self, input, target, eps=1e-3):\n        input = input.reshape(input.size()[0], -1)\n        target = target.reshape(target.size()[0], -1).float()\n        # einsum saves 10x memory\n        # a = torch.sum(input[:, None] * target[None, ...], -1)\n        a = torch.einsum('nh,mh->nm', input, target)\n        b = torch.sum(input * input, 1) + eps\n        c = torch.sum(target * target, 1) + eps\n        d = (2 * a) / (b[:, None] + c[None, ...])\n        # the constant 1 in (1 - dice) does not affect the matching, so it is omitted\n        return -d\n\n    def __call__(self, mask_preds, gt_masks):\n        \"\"\"\n        Args:\n            mask_preds (Tensor): Predicted masks (logits if ``pred_act`` is\n                True), shape [num_query, H, W].\n            gt_masks (Tensor): Ground truth binary masks, shape\n                [num_gt, H, W].\n\n        Returns:\n            torch.Tensor: dice_cost value with weight, shape\n                [num_query, num_gt].\n        \"\"\"\n        if self.pred_act and self.act_mode == 'sigmoid':\n            mask_preds = mask_preds.sigmoid()\n        elif self.pred_act:\n            mask_preds = mask_preds.softmax(dim=0)\n        dice_cost = self.dice_loss(mask_preds, gt_masks, self.eps)\n        return dice_cost * self.weight\n\n\n@MATCH_COST.register_module()\nclass MaskCost(object):\n    \"\"\"MaskCost.\n\n    Args:\n        weight (int | float, optional): loss_weight\n        pred_act (bool): Whether to activate the prediction\n            before calculating cost\n        act_mode (str): Activation applied when ``pred_act`` is True:\n            'sigmoid', otherwise softmax over dim 0. Default 'sigmoid'.\n    \"\"\"\n\n    def __init__(self, weight=1., pred_act=False, act_mode='sigmoid'):\n        self.weight = weight\n        self.pred_act = pred_act\n        self.act_mode = act_mode\n\n    def __call__(self, cls_pred, target):\n        \"\"\"\n        Args:\n            cls_pred (Tensor): Predicted mask logits or probabilities,\n                shape [num_query, H, W].\n            target (Tensor): Ground truth binary masks, shape [num_gt, H, W].\n\n        Returns:\n            torch.Tensor: mask cost value with weight, shape\n                [num_query, num_gt].\n        \"\"\"\n        if self.pred_act and self.act_mode == 'sigmoid':\n            cls_pred = cls_pred.sigmoid()\n        elif self.pred_act:\n            cls_pred = cls_pred.softmax(dim=0)\n\n        _, H, W = target.shape\n        # flatten_cls_pred = cls_pred.view(num_proposals, -1)\n        # einsum is ~10 times faster than matmul\n        pos_cost = torch.einsum('nhw,mhw->nm', cls_pred, target)\n        neg_cost = torch.einsum('nhw,mhw->nm', 1 - cls_pred, 1 - target)\n        cls_cost = -(pos_cost + neg_cost) / (H * W)\n        return cls_cost * self.weight\n\n\n@BBOX_ASSIGNERS.register_module()\nclass MaskHungarianAssigner(BaseAssigner):\n    \"\"\"Computes one-to-one 
matching between predictions and ground truth.\n\n    This class computes an assignment between the targets and the predictions\n    based on the costs. The cost is a weighted sum of up to four components:\n    classification cost, mask cost, dice cost and an optional boundary cost.\n    The targets don't include the no_object, so generally there are more\n    predictions than targets. After the one-to-one matching, the unmatched\n    predictions are treated as background. Thus each query prediction is\n    assigned `0` or a positive integer indicating the ground truth index:\n\n    - 0: negative sample, no assigned gt\n    - positive integer: positive sample, index (1-based) of assigned gt\n\n    Args:\n        cls_cost (dict): Config of the classification cost.\n            Default: ClassificationCost with weight 1.0.\n        mask_cost (dict): Config of the mask cost. Default weight 1.0.\n        dice_cost (dict): Config of the dice cost.\n        boundary_cost (dict, optional): Config of the boundary cost. If None,\n            no boundary cost is used. Default None.\n        topk (int): Number of predictions matched to each ground truth.\n            When ``topk > 1``, the Hungarian matching is repeated ``topk``\n            times and predictions matched in earlier rounds are excluded, so\n            each ground truth receives up to ``topk`` predictions.
\n            Default 1.\n    \"\"\"\n\n    def __init__(self,\n                 cls_cost=dict(type='ClassificationCost', weight=1.),\n                 mask_cost=dict(type='SigmoidCost', weight=1.0),\n                 dice_cost=dict(),\n                 boundary_cost=None,\n                 topk=1):\n        self.cls_cost = build_match_cost(cls_cost)\n        self.mask_cost = build_match_cost(mask_cost)\n        self.dice_cost = build_match_cost(dice_cost)\n        if boundary_cost is not None:\n            self.boundary_cost = build_match_cost(boundary_cost)\n        else:\n            self.boundary_cost = None\n        self.topk = topk\n\n    def assign(self,\n               bbox_pred,\n               cls_pred,\n               gt_bboxes,\n               gt_labels,\n               img_meta=None,\n               gt_bboxes_ignore=None,\n               eps=1e-7):\n        \"\"\"Computes one-to-one matching based on the weighted costs.\n\n        This method assigns each query prediction to a ground truth or\n        background. The `assigned_gt_inds` with -1 means don't care,\n        0 means negative sample, and positive number is the index (1-based)\n        of assigned gt.\n        The assignment is done in the following steps, the order matters.\n\n        1. assign every prediction to -1\n        2. compute the weighted costs\n        3. do Hungarian matching on CPU based on the costs\n        4. assign all to 0 (background) first, then for each matched pair\n           between predictions and gts, treat this prediction as foreground\n           and assign the corresponding gt index (plus 1) to it.\n\n        Args:\n            bbox_pred (Tensor): Predicted masks with shape\n                [num_query, H, W].
\n            cls_pred (Tensor): Predicted classification logits, shape\n                [num_query, num_class].\n            gt_bboxes (Tensor): Ground truth masks with shape\n                [num_gt, H, W].\n            gt_labels (Tensor): Labels of `gt_bboxes`, shape (num_gt,).\n            img_meta (dict): Meta information for current image.\n            gt_bboxes_ignore (Tensor, optional): Ground truth masks that are\n                labelled as `ignored`. Only None is supported.\n            eps (int | float, optional): Currently unused. Default 1e-7.\n\n        Returns:\n            :obj:`AssignResult`: The assigned result.\n        \"\"\"\n        assert gt_bboxes_ignore is None, \\\n            'Only case when gt_bboxes_ignore is None is supported.'\n        num_gts, num_bboxes = gt_bboxes.size(0), bbox_pred.size(0)\n\n        # 1. assign -1 by default\n        assigned_gt_inds = bbox_pred.new_full((num_bboxes, ),\n                                              -1,\n                                              dtype=torch.long)\n        assigned_labels = bbox_pred.new_full((num_bboxes, ),\n                                             -1,\n                                             dtype=torch.long)\n        if num_gts == 0 or num_bboxes == 0:\n            # No ground truth or predictions, return empty assignment\n            if num_gts == 0:\n                # No ground truth, assign all to background\n                assigned_gt_inds[:] = 0\n            return AssignResult(\n                num_gts, assigned_gt_inds, None, labels=assigned_labels)\n\n        # 2. 
compute the weighted costs\n        # classification, mask, dice and (optional) boundary costs\n        if self.cls_cost.weight != 0 and cls_pred is not None:\n            cls_cost = self.cls_cost(cls_pred, gt_labels)\n        else:\n            cls_cost = 0\n        if self.mask_cost.weight != 0:\n            reg_cost = self.mask_cost(bbox_pred, gt_bboxes)\n        else:\n            reg_cost = 0\n        if self.dice_cost.weight != 0:\n            dice_cost = self.dice_cost(bbox_pred, gt_bboxes)\n        else:\n            dice_cost = 0\n        if self.boundary_cost is not None and self.boundary_cost.weight != 0:\n            b_cost = self.boundary_cost(bbox_pred, gt_bboxes)\n        else:\n            b_cost = 0\n        cost = cls_cost + reg_cost + dice_cost + b_cost\n\n        # 3. do Hungarian matching on CPU using linear_sum_assignment\n        cost = cost.detach().cpu()\n        if linear_sum_assignment is None:\n            raise ImportError('Please run \"pip install scipy\" '\n                              'to install scipy first.')\n        if self.topk == 1:\n            matched_row_inds, matched_col_inds = linear_sum_assignment(cost)\n        else:\n            topk_matched_row_inds = []\n            topk_matched_col_inds = []\n            for _ in range(self.topk):\n                matched_row_inds, matched_col_inds = linear_sum_assignment(\n                    cost)\n                topk_matched_row_inds.append(matched_row_inds)\n                topk_matched_col_inds.append(matched_col_inds)\n                # exclude predictions matched in this round from later rounds\n                cost[matched_row_inds] = 1e10\n            matched_row_inds = np.concatenate(topk_matched_row_inds)\n            matched_col_inds = np.concatenate(topk_matched_col_inds)\n\n        matched_row_inds = torch.from_numpy(matched_row_inds).to(\n            bbox_pred.device)\n        matched_col_inds = torch.from_numpy(matched_col_inds).to(\n            bbox_pred.device)\n\n        # 4. 
assign backgrounds and foregrounds\n        # assign all indices to backgrounds first\n        assigned_gt_inds[:] = 0\n        # assign foregrounds based on matching results\n        assigned_gt_inds[matched_row_inds] = matched_col_inds + 1\n        assigned_labels[matched_row_inds] = gt_labels[matched_col_inds]\n        return AssignResult(\n            num_gts, assigned_gt_inds, None, labels=assigned_labels)\n"
  },
  {
    "path": "knet_vis/det/mask_pseudo_sampler.py",
    "content": "import torch\n\nfrom mmdet.core.bbox import BaseSampler, SamplingResult\nfrom mmdet.core.bbox.builder import BBOX_SAMPLERS\n\n\nclass MaskSamplingResult(SamplingResult):\n    \"\"\"Mask sampling result.\n\n    Example:\n        >>> # xdoctest: +IGNORE_WANT\n        >>> from mmdet.core.bbox.samplers.sampling_result import *  # NOQA\n        >>> self = SamplingResult.random(rng=10)\n        >>> print(f'self = {self}')\n        self = <SamplingResult({\n            'neg_masks': torch.Size([12, 4]),\n            'neg_inds': tensor([ 0,  1,  2,  4,  5,  6,  7,  8,  9, 10, 11, 12]),\n            'num_gts': 4,\n            'pos_assigned_gt_inds': tensor([], dtype=torch.int64),\n            'pos_masks': torch.Size([0, 4]),\n            'pos_inds': tensor([], dtype=torch.int64),\n            'pos_is_gt': tensor([], dtype=torch.uint8)\n        })>\n    \"\"\"\n\n    def __init__(self, pos_inds, neg_inds, masks, gt_masks, assign_result,\n                 gt_flags):\n        self.pos_inds = pos_inds\n        self.neg_inds = neg_inds\n        self.pos_masks = masks[pos_inds]\n        self.neg_masks = masks[neg_inds]\n        self.pos_is_gt = gt_flags[pos_inds]\n\n        self.num_gts = gt_masks.shape[0]\n        self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1\n\n        if gt_masks.numel() == 0:\n            # hack for index error case\n            assert self.pos_assigned_gt_inds.numel() == 0\n            self.pos_gt_masks = torch.empty_like(gt_masks)\n        else:\n            self.pos_gt_masks = gt_masks[self.pos_assigned_gt_inds, :]\n\n        if assign_result.labels is not None:\n            self.pos_gt_labels = assign_result.labels[pos_inds]\n        else:\n            self.pos_gt_labels = None\n\n    @property\n    def masks(self):\n        \"\"\"torch.Tensor: concatenated positive and negative masks\"\"\"\n        return torch.cat([self.pos_masks, self.neg_masks])
\n\n    def __nice__(self):\n        data = self.info.copy()\n        data['pos_masks'] = data.pop('pos_masks').shape\n        data['neg_masks'] = data.pop('neg_masks').shape\n        parts = [f\"'{k}': {v!r}\" for k, v in sorted(data.items())]\n        body = '    ' + ',\\n    '.join(parts)\n        return '{\\n' + body + '\\n}'\n\n    @property\n    def info(self):\n        \"\"\"Returns a dictionary of info about the object.\"\"\"\n        return {\n            'pos_inds': self.pos_inds,\n            'neg_inds': self.neg_inds,\n            'pos_masks': self.pos_masks,\n            'neg_masks': self.neg_masks,\n            'pos_is_gt': self.pos_is_gt,\n            'num_gts': self.num_gts,\n            'pos_assigned_gt_inds': self.pos_assigned_gt_inds,\n        }\n\n\n@BBOX_SAMPLERS.register_module()\nclass MaskPseudoSampler(BaseSampler):\n    \"\"\"A pseudo sampler that does not actually perform any sampling.\"\"\"\n\n    def __init__(self, **kwargs):\n        pass\n\n    def _sample_pos(self, **kwargs):\n        \"\"\"Sample positive samples.\"\"\"\n        raise NotImplementedError\n\n    def _sample_neg(self, **kwargs):\n        \"\"\"Sample negative samples.\"\"\"\n        raise NotImplementedError\n\n    def sample(self, assign_result, masks, gt_masks, **kwargs):\n        \"\"\"Directly returns the positive and negative indices of samples.\n\n        Args:\n            assign_result (:obj:`AssignResult`): Assigned results\n            masks (torch.Tensor): Predicted masks\n            gt_masks (torch.Tensor): Ground truth masks\n\n        Returns:\n            :obj:`MaskSamplingResult`: sampler results\n        \"\"\"\n        pos_inds = torch.nonzero(\n            assign_result.gt_inds > 0, as_tuple=False).squeeze(-1).unique()\n        neg_inds = torch.nonzero(\n            assign_result.gt_inds == 0, as_tuple=False).squeeze(-1).unique()\n        gt_flags = masks.new_zeros(masks.shape[0], dtype=torch.uint8)\n        sampling_result = MaskSamplingResult(pos_inds, neg_inds, masks,\n                                             gt_masks, 
assign_result, gt_flags)\n        return sampling_result\n"
  },
  {
    "path": "knet_vis/det/semantic_fpn_wrapper.py",
    "content": "import torch\nimport torch.nn as nn\nfrom mmcv.cnn import ConvModule, normal_init\nfrom mmdet.models.builder import NECKS\nfrom mmcv.cnn.bricks.transformer import build_positional_encoding\nfrom mmdet.utils import get_root_logger\n\n\n@NECKS.register_module()\nclass SemanticFPNWrapper(nn.Module):\n    \"\"\"Implementation of Semantic FPN used in Panoptic FPN.\n\n    Args:\n        in_channels (int): Number of channels of each input feature level.\n        feat_channels (int): Number of channels of the intermediate convs.\n        out_channels (int): Number of channels of the output convs.\n        start_level (int): Index of the first input level to use.\n        end_level (int): Index of the last input level to use (inclusive).\n        cat_coors (bool, optional): Whether to concatenate normalized (x, y)\n            coordinate maps to the features of ``cat_coors_level``.\n            Defaults to False.\n        positional_encoding (dict, optional): Config of a positional\n            encoding added to the features of ``cat_coors_level``.\n            Defaults to None.\n        cat_coors_level (int, optional): Level that receives the positional\n            encoding and/or the coordinate maps. Defaults to 3.\n        fuse_by_cat (bool, optional): Whether to fuse the levels by channel\n            concatenation instead of summation. Defaults to False.\n        return_list (bool, optional): Whether to wrap the output in a list.\n            Defaults to False.\n        upsample_times (int, optional): Number of 2x upsampling steps applied\n            to the ``end_level`` features. Defaults to 3.\n        with_pred (bool, optional): Whether to apply the final 1x1 conv to\n            the fused features. Defaults to True.\n        num_aux_convs (int, optional): Number of extra 1x1 output convs. If\n            greater than 0, a list of outputs is returned. Defaults to 0.\n        act_cfg (dict, optional): Activation config of the intermediate\n            convs. Defaults to dict(type='ReLU', inplace=True).\n        out_act_cfg (dict, optional): Activation config of the output convs.\n            Defaults to dict(type='ReLU').\n        conv_cfg (dict, optional): Config of the conv layers.\n            Defaults to None.\n        norm_cfg (dict, optional): Config of the norm layers.\n            Defaults to None.\n    \"\"\"\n\n    def __init__(self,\n                 in_channels,\n                 feat_channels,\n                 out_channels,\n                 start_level,\n                 end_level,\n                 cat_coors=False,\n                 positional_encoding=None,\n                 cat_coors_level=3,\n                 fuse_by_cat=False,\n                 return_list=False,\n                 upsample_times=3,\n                 with_pred=True,\n                 num_aux_convs=0,\n                 act_cfg=dict(type='ReLU', inplace=True),\n                 out_act_cfg=dict(type='ReLU'),\n                 conv_cfg=None,\n                 norm_cfg=None):\n        super(SemanticFPNWrapper, self).__init__()\n\n        self.in_channels = in_channels\n        self.feat_channels = feat_channels\n        self.start_level = start_level\n        self.end_level = end_level\n        assert start_level >= 0 and end_level >= start_level\n        self.out_channels = out_channels\n        self.conv_cfg = conv_cfg\n        self.norm_cfg = norm_cfg\n        self.act_cfg = act_cfg
\n        self.cat_coors = cat_coors\n        self.cat_coors_level = cat_coors_level\n        self.fuse_by_cat = fuse_by_cat\n        self.return_list = return_list\n        self.upsample_times = upsample_times\n        self.with_pred = with_pred\n        if positional_encoding is not None:\n            self.positional_encoding = build_positional_encoding(\n                positional_encoding)\n        else:\n            self.positional_encoding = None\n\n        self.convs_all_levels = nn.ModuleList()\n        for i in range(self.start_level, self.end_level + 1):\n            convs_per_level = nn.Sequential()\n            if i == 0:\n                if i == self.cat_coors_level and self.cat_coors:\n                    chn = self.in_channels + 2\n                else:\n                    chn = self.in_channels\n                if upsample_times == self.end_level - i:\n                    one_conv = ConvModule(\n                        chn,\n                        self.feat_channels,\n                        3,\n                        padding=1,\n                        conv_cfg=self.conv_cfg,\n                        norm_cfg=self.norm_cfg,\n                        act_cfg=self.act_cfg,\n                        inplace=False)\n                    convs_per_level.add_module('conv' + str(i), one_conv)\n                else:\n                    # stride-2 convs; only the first one takes the raw input,\n                    # the following ones take feat_channels\n                    for k in range(self.end_level - upsample_times):\n                        one_conv = ConvModule(\n                            chn if k == 0 else self.feat_channels,\n                            self.feat_channels,\n                            3,\n                            padding=1,\n                            stride=2,\n                            conv_cfg=self.conv_cfg,\n                            norm_cfg=self.norm_cfg,\n                            act_cfg=self.act_cfg,\n                            inplace=False)\n                        convs_per_level.add_module('conv' + str(k), one_conv)\n                self.convs_all_levels.append(convs_per_level)\n                continue
\n\n            for j in range(i):\n                if j == 0:\n                    if i == self.cat_coors_level and self.cat_coors:\n                        chn = self.in_channels + 2\n                    else:\n                        chn = self.in_channels\n                    one_conv = ConvModule(\n                        chn,\n                        self.feat_channels,\n                        3,\n                        padding=1,\n                        conv_cfg=self.conv_cfg,\n                        norm_cfg=self.norm_cfg,\n                        act_cfg=self.act_cfg,\n                        inplace=False)\n                    convs_per_level.add_module('conv' + str(j), one_conv)\n                    if j < upsample_times - (self.end_level - i):\n                        one_upsample = nn.Upsample(\n                            scale_factor=2,\n                            mode='bilinear',\n                            align_corners=False)\n                        convs_per_level.add_module('upsample' + str(j),\n                                                   one_upsample)\n                    continue\n\n                one_conv = ConvModule(\n                    self.feat_channels,\n                    self.feat_channels,\n                    3,\n                    padding=1,\n                    conv_cfg=self.conv_cfg,\n                    norm_cfg=self.norm_cfg,\n                    act_cfg=self.act_cfg,\n                    inplace=False)\n                convs_per_level.add_module('conv' + str(j), one_conv)\n                if j < upsample_times - (self.end_level - i):\n                    one_upsample = nn.Upsample(\n                        scale_factor=2, mode='bilinear', align_corners=False)\n                    convs_per_level.add_module('upsample' + str(j),\n                                               one_upsample)\n\n            self.convs_all_levels.append(convs_per_level)\n\n        if fuse_by_cat:
\n            in_channels = self.feat_channels * len(self.convs_all_levels)\n        else:\n            in_channels = self.feat_channels\n\n        if self.with_pred:\n            self.conv_pred = ConvModule(\n                in_channels,\n                self.out_channels,\n                1,\n                padding=0,\n                conv_cfg=self.conv_cfg,\n                act_cfg=out_act_cfg,\n                norm_cfg=self.norm_cfg)\n\n        self.num_aux_convs = num_aux_convs\n        self.aux_convs = nn.ModuleList()\n        for i in range(num_aux_convs):\n            self.aux_convs.append(\n                ConvModule(\n                    in_channels,\n                    self.out_channels,\n                    1,\n                    padding=0,\n                    conv_cfg=self.conv_cfg,\n                    act_cfg=out_act_cfg,\n                    norm_cfg=self.norm_cfg))\n\n    def init_weights(self):\n        logger = get_root_logger()\n        logger.info('Use normal initialization for semantic FPN')\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                normal_init(m, std=0.01)\n\n    def generate_coord(self, input_feat):\n        x_range = torch.linspace(\n            -1, 1, input_feat.shape[-1], device=input_feat.device)\n        y_range = torch.linspace(\n            -1, 1, input_feat.shape[-2], device=input_feat.device)\n        y, x = torch.meshgrid(y_range, x_range)\n        y = y.expand([input_feat.shape[0], 1, -1, -1])\n        x = x.expand([input_feat.shape[0], 1, -1, -1])\n        coord_feat = torch.cat([x, y], 1)\n        return coord_feat\n\n    def forward(self, inputs):\n        mlvl_feats = []\n        for i in range(self.start_level, self.end_level + 1):\n            input_p = inputs[i]\n            if i == self.cat_coors_level:\n                if self.positional_encoding is not None:\n                    ignore_mask = input_p.new_zeros(\n                        (input_p.shape[0], input_p.shape[-2],
\n                         input_p.shape[-1]),\n                        dtype=torch.bool)\n                    positional_encoding = self.positional_encoding(ignore_mask)\n                    input_p = input_p + positional_encoding\n                if self.cat_coors:\n                    coord_feat = self.generate_coord(input_p)\n                    input_p = torch.cat([input_p, coord_feat], 1)\n\n            # convs_all_levels is indexed relative to start_level\n            convs = self.convs_all_levels[i - self.start_level]\n            mlvl_feats.append(convs(input_p))\n\n        if self.fuse_by_cat:\n            feature_add_all_level = torch.cat(mlvl_feats, dim=1)\n        else:\n            feature_add_all_level = sum(mlvl_feats)\n\n        if self.with_pred:\n            out = self.conv_pred(feature_add_all_level)\n        else:\n            out = feature_add_all_level\n\n        if self.num_aux_convs > 0:\n            outs = [out]\n            for conv in self.aux_convs:\n                outs.append(conv(feature_add_all_level))\n            return outs\n\n        if self.return_list:\n            return [out]\n        else:\n            return out\n"
  },
  {
    "path": "knet_vis/det/utils.py",
    "content": "import torch\n\n\ndef sem2ins_masks(gt_sem_seg,\n                  num_thing_classes=80):\n    \"\"\"Convert a semantic segmentation mask to binary masks of stuff classes.\n\n    Args:\n        gt_sem_seg (torch.Tensor): Semantic mask to be converted, with shape\n            (1, H, W). Class IDs in [0, num_thing_classes - 1] are thing\n            classes and IDs >= num_thing_classes are stuff classes.\n        num_thing_classes (int, optional): Number of thing classes.\n            Defaults to 80.\n\n    Returns:\n        tuple[torch.Tensor]: (mask_labels, bin_masks).\n            Labels with shape (K,) and float binary masks with shape\n            (K, H, W) of the K stuff classes present in ``gt_sem_seg``.\n    \"\"\"\n    # gt_sem_seg is zero-started, where zero indicates the first class\n    # since mmdet>=2.17.0, see more discussion in\n    # https://mmdetection.readthedocs.io/en/latest/conventions.html#coco-panoptic-dataset  # noqa\n    classes = torch.unique(gt_sem_seg)\n    # class IDs range from 0 to N-1, where the class IDs in\n    # [0, num_thing_classes - 1] are IDs of thing classes\n    masks = []\n    labels = []\n\n    for i in classes:\n        # skip ignore class 255 and \"thing classes\" in semantic seg\n        if i == 255 or i < num_thing_classes:\n            continue\n        labels.append(i)\n        masks.append(gt_sem_seg == i)\n\n    if len(labels) > 0:\n        labels = torch.stack(labels)\n        masks = torch.cat(masks)\n    else:\n        labels = gt_sem_seg.new_zeros(size=[0])\n        masks = gt_sem_seg.new_zeros(\n            size=[0, gt_sem_seg.shape[-2], gt_sem_seg.shape[-1]])\n    return labels.long(), masks.float()\n"
  },
  {
    "path": "knet_vis/kernel_updator.py",
    "content": "import torch.nn as nn\nfrom mmcv.cnn import build_activation_layer, build_norm_layer\nfrom mmcv.cnn.bricks.transformer import TRANSFORMER_LAYER\n\n\n@TRANSFORMER_LAYER.register_module()\nclass KernelUpdator(nn.Module):\n\n    def __init__(self,\n                 in_channels=256,\n                 feat_channels=64,\n                 out_channels=None,\n                 input_feat_shape=3,\n                 gate_sigmoid=True,\n                 gate_norm_act=False,\n                 activate_out=False,\n                 act_cfg=dict(type='ReLU', inplace=True),\n                 norm_cfg=dict(type='LN')):\n        super(KernelUpdator, self).__init__()\n        self.in_channels = in_channels\n        self.feat_channels = feat_channels\n        self.out_channels_raw = out_channels\n        self.gate_sigmoid = gate_sigmoid\n        self.gate_norm_act = gate_norm_act\n        self.activate_out = activate_out\n        if isinstance(input_feat_shape, int):\n            input_feat_shape = [input_feat_shape] * 2\n        self.input_feat_shape = input_feat_shape\n        self.act_cfg = act_cfg\n        self.norm_cfg = norm_cfg\n        self.out_channels = out_channels if out_channels else in_channels\n\n        self.num_params_in = self.feat_channels\n        self.num_params_out = self.feat_channels\n        self.dynamic_layer = nn.Linear(\n            self.in_channels, self.num_params_in + self.num_params_out)\n        self.input_layer = nn.Linear(self.in_channels,\n                                     self.num_params_in + self.num_params_out,\n                                     bias=True)\n        self.input_gate = nn.Linear(\n            self.in_channels, self.feat_channels, bias=True)\n        self.update_gate = nn.Linear(\n            self.in_channels, self.feat_channels, bias=True)\n        if self.gate_norm_act:\n            self.gate_norm = build_norm_layer(norm_cfg, self.feat_channels)[1]\n\n        self.norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1]
\n        self.norm_out = build_norm_layer(norm_cfg, self.feat_channels)[1]\n        self.input_norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1]\n        self.input_norm_out = build_norm_layer(norm_cfg, self.feat_channels)[1]\n\n        self.activation = build_activation_layer(act_cfg)\n\n        self.fc_layer = nn.Linear(\n            self.feat_channels, self.out_channels, bias=True)\n        self.fc_norm = build_norm_layer(norm_cfg, self.out_channels)[1]\n\n    def forward(self, update_feature, input_feature):\n        update_feature = update_feature.reshape(-1, self.in_channels)\n        num_proposals = update_feature.size(0)\n        parameters = self.dynamic_layer(update_feature)\n        param_in = parameters[:, :self.num_params_in].view(\n            -1, self.feat_channels)\n        param_out = parameters[:, -self.num_params_out:].view(\n            -1, self.feat_channels)\n\n        input_feats = self.input_layer(\n            input_feature.reshape(num_proposals, -1, self.feat_channels))\n        input_in = input_feats[..., :self.num_params_in]\n        input_out = input_feats[..., -self.num_params_out:]\n\n        gate_feats = input_in * param_in.unsqueeze(-2)\n        if self.gate_norm_act:\n            gate_feats = self.activation(self.gate_norm(gate_feats))\n\n        input_gate = self.input_norm_in(self.input_gate(gate_feats))\n        update_gate = self.norm_in(self.update_gate(gate_feats))\n        if self.gate_sigmoid:\n            input_gate = input_gate.sigmoid()\n            update_gate = update_gate.sigmoid()\n        param_out = self.norm_out(param_out)\n        input_out = self.input_norm_out(input_out)\n\n        if self.activate_out:\n            param_out = self.activation(param_out)\n            input_out = self.activation(input_out)\n\n        # param_out has shape (num_proposals, feat_channels); unsqueeze(-2)\n        # makes it (num_proposals, 1, feat_channels) so that it broadcasts\n        # against input_out of shape (num_proposals, -1, feat_channels)\n        features = update_gate * param_out.unsqueeze(\n            -2) + input_gate * input_out\n\n        features = self.fc_layer(features)
\n        features = self.fc_norm(features)\n        features = self.activation(features)\n\n        return features\n"
  },
  {
    "path": "knet_vis/tracker/__init__.py",
    "content": ""
  },
  {
    "path": "knet_vis/tracker/kernel_frame_head.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom mmcv.cnn import (ConvModule, bias_init_with_prob, normal_init)\nfrom mmdet.core import build_assigner, build_sampler, multi_apply, reduce_mean\nfrom mmdet.models.builder import HEADS, build_loss, build_neck\nfrom mmdet.models.losses import accuracy\nfrom mmdet.utils import get_root_logger\n\n\n@HEADS.register_module()\nclass ConvKernelHeadVolume(nn.Module):\n    def __init__(self,\n                 num_proposals=100,\n                 in_channels=256,\n                 out_channels=256,\n                 num_heads=8,\n                 num_cls_fcs=1,\n                 num_seg_convs=1,\n                 num_loc_convs=1,\n                 att_dropout=False,\n                 localization_fpn=None,\n                 conv_kernel_size=1,\n                 norm_cfg=dict(type='GN', num_groups=32),\n                 semantic_fpn=True,\n                 train_cfg=None,\n                 num_classes=80,\n                 xavier_init_kernel=False,\n                 kernel_init_std=0.01,\n                 use_binary=False,\n                 proposal_feats_with_obj=False,\n                 loss_mask=None,\n                 loss_seg=None,\n                 loss_cls=None,\n                 loss_dice=None,\n                 loss_rank=None,\n                 feat_downsample_stride=1,\n                 feat_refine_stride=1,\n                 feat_refine=True,\n                 with_embed=False,\n                 feat_embed_only=False,\n                 conv_normal_init=False,\n                 mask_out_stride=4,\n                 hard_target=False,\n                 num_thing_classes=80,\n                 num_stuff_classes=53,\n                 mask_assign_stride=4,\n                 ignore_label=255,\n                 thing_label_in_seg=0,\n                 cat_stuff_mask=False,\n                 **kwargs):\n        super().__init__()\n        self.num_proposals = num_proposals\n        
self.num_cls_fcs = num_cls_fcs\n        self.train_cfg = train_cfg\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.num_classes = num_classes\n        self.proposal_feats_with_obj = proposal_feats_with_obj\n        self.sampling = False\n        self.localization_fpn = build_neck(localization_fpn)\n        self.semantic_fpn = semantic_fpn\n        self.norm_cfg = norm_cfg\n        self.num_heads = num_heads\n        self.att_dropout = att_dropout\n        self.mask_out_stride = mask_out_stride\n        self.hard_target = hard_target\n        self.conv_kernel_size = conv_kernel_size\n        self.xavier_init_kernel = xavier_init_kernel\n        self.kernel_init_std = kernel_init_std\n        self.feat_downsample_stride = feat_downsample_stride\n        self.feat_refine_stride = feat_refine_stride\n        self.conv_normal_init = conv_normal_init\n        self.feat_refine = feat_refine\n        self.with_embed = with_embed\n        self.feat_embed_only = feat_embed_only\n        self.num_loc_convs = num_loc_convs\n        self.num_seg_convs = num_seg_convs\n        self.use_binary = use_binary\n        self.num_thing_classes = num_thing_classes\n        self.num_stuff_classes = num_stuff_classes\n        self.mask_assign_stride = mask_assign_stride\n        self.ignore_label = ignore_label\n        self.thing_label_in_seg = thing_label_in_seg\n        self.cat_stuff_mask = cat_stuff_mask\n\n        if loss_mask is not None:\n            self.loss_mask = build_loss(loss_mask)\n        else:\n            self.loss_mask = loss_mask\n\n        if loss_dice is not None:\n            self.loss_dice = build_loss(loss_dice)\n        else:\n            self.loss_dice = loss_dice\n\n        if loss_seg is not None:\n            self.loss_seg = build_loss(loss_seg)\n        else:\n            self.loss_seg = loss_seg\n        if loss_cls is not None:\n            self.loss_cls = build_loss(loss_cls)\n        else:\n            
self.loss_cls = loss_cls\n\n        if loss_rank is not None:\n            self.loss_rank = build_loss(loss_rank)\n        else:\n            self.loss_rank = loss_rank\n\n        if self.train_cfg:\n            self.assigner = build_assigner(self.train_cfg.assigner)\n            # use PseudoSampler when sampling is False\n            if self.sampling and hasattr(self.train_cfg, 'sampler'):\n                sampler_cfg = self.train_cfg.sampler\n            else:\n                sampler_cfg = dict(type='MaskPseudoSampler')\n            self.sampler = build_sampler(sampler_cfg, context=self)\n        self._init_layers()\n\n    def _init_layers(self):\n        \"\"\"Initialize a sparse set of proposal boxes and proposal features.\"\"\"\n        self.init_kernels = nn.Conv2d(\n            self.out_channels,\n            self.num_proposals,\n            self.conv_kernel_size,\n            padding=int(self.conv_kernel_size // 2),\n            bias=False)\n\n        if self.semantic_fpn:\n            if self.loss_seg.use_sigmoid:\n                self.conv_seg = nn.Conv2d(self.out_channels, self.num_classes,\n                                          1)\n            else:\n                self.conv_seg = nn.Conv2d(self.out_channels,\n                                          self.num_classes + 1, 1)\n\n        if self.feat_downsample_stride > 1 and self.feat_refine:\n            self.ins_downsample = ConvModule(\n                self.in_channels,\n                self.out_channels,\n                3,\n                stride=self.feat_refine_stride,\n                padding=1,\n                norm_cfg=self.norm_cfg)\n            self.seg_downsample = ConvModule(\n                self.in_channels,\n                self.out_channels,\n                3,\n                stride=self.feat_refine_stride,\n                padding=1,\n                norm_cfg=self.norm_cfg)\n\n        self.loc_convs = nn.ModuleList()\n        for i in range(self.num_loc_convs):\n            
self.loc_convs.append(\n                ConvModule(\n                    self.in_channels,\n                    self.out_channels,\n                    1,\n                    norm_cfg=self.norm_cfg))\n\n        self.seg_convs = nn.ModuleList()\n        for i in range(self.num_seg_convs):\n            self.seg_convs.append(\n                ConvModule(\n                    self.in_channels,\n                    self.out_channels,\n                    1,\n                    norm_cfg=self.norm_cfg))\n\n    def init_weights(self):\n        self.localization_fpn.init_weights()\n\n        if self.feat_downsample_stride > 1 and self.conv_normal_init:\n            logger = get_root_logger()\n            logger.info('Initialize convs in KPN head by normal std 0.01')\n            for conv in [self.loc_convs, self.seg_convs]:\n                for m in conv.modules():\n                    if isinstance(m, nn.Conv2d):\n                        normal_init(m, std=0.01)\n\n        if self.semantic_fpn:\n            bias_seg = bias_init_with_prob(0.01)\n            if self.loss_seg.use_sigmoid:\n                normal_init(self.conv_seg, std=0.01, bias=bias_seg)\n            else:\n                normal_init(self.conv_seg, mean=0, std=0.01)\n        if self.xavier_init_kernel:\n            logger = get_root_logger()\n            logger.info('Initialize kernels by xavier uniform')\n            nn.init.xavier_uniform_(self.init_kernels.weight)\n        else:\n            logger = get_root_logger()\n            logger.info(\n                f'Initialize kernels by normal std: {self.kernel_init_std}')\n            normal_init(self.init_kernels, mean=0, std=self.kernel_init_std)\n\n    def _decode_init_proposals(self, img, img_metas, ref_img_metas):\n        num_imgs = len(img_metas)\n        num_frames = len(ref_img_metas[0])\n\n        if self.localization_fpn.__class__.__name__.endswith('3D'):\n            localization_feats = self.localization_fpn(img, num_imgs, num_frames)\n     
   else:\n            localization_feats = self.localization_fpn(img)\n        if isinstance(localization_feats, list):\n            loc_feats = localization_feats[0]\n        else:\n            loc_feats = localization_feats\n        for conv in self.loc_convs:\n            loc_feats = conv(loc_feats)\n        if self.feat_downsample_stride > 1 and self.feat_refine:\n            loc_feats = self.ins_downsample(loc_feats)\n        mask_preds = self.init_kernels(loc_feats)\n\n        if self.semantic_fpn:\n            if isinstance(localization_feats, list):\n                semantic_feats = localization_feats[1]\n            else:\n                semantic_feats = localization_feats\n            for conv in self.seg_convs:\n                semantic_feats = conv(semantic_feats)\n            if self.feat_downsample_stride > 1 and self.feat_refine:\n                semantic_feats = self.seg_downsample(semantic_feats)\n        else:\n            semantic_feats = None\n\n        if semantic_feats is not None:\n            seg_preds = self.conv_seg(semantic_feats)\n        else:\n            seg_preds = None\n\n        proposal_feats = self.init_kernels.weight.clone()\n        proposal_feats = proposal_feats[None].expand(num_imgs * num_frames, *proposal_feats.size())\n\n        if semantic_feats is not None:\n            x_feats = semantic_feats + loc_feats\n        else:\n            x_feats = loc_feats\n\n        if self.proposal_feats_with_obj:\n            sigmoid_masks = mask_preds.sigmoid()\n            nonzero_inds = sigmoid_masks > 0.5\n            if self.use_binary:\n                sigmoid_masks = nonzero_inds.float()\n            else:\n                sigmoid_masks = nonzero_inds.float() * sigmoid_masks\n            obj_feats = torch.einsum('bnhw,bchw->bnc', sigmoid_masks, x_feats)\n\n        cls_scores = None\n\n        if self.proposal_feats_with_obj:\n            proposal_feats = proposal_feats + obj_feats.view(\n                num_imgs * num_frames, 
self.num_proposals, self.out_channels, 1, 1)\n\n        if self.cat_stuff_mask and not self.training:\n            mask_preds = torch.cat(\n                [mask_preds, seg_preds[:, self.num_thing_classes:]], dim=1)\n            stuff_kernels = self.conv_seg.weight[self.\n                                                 num_thing_classes:].clone()\n            stuff_kernels = stuff_kernels[None].expand(num_imgs * num_frames, *stuff_kernels.size())\n            proposal_feats = torch.cat([proposal_feats, stuff_kernels], dim=1)\n\n        return proposal_feats, x_feats, mask_preds, cls_scores, seg_preds\n\n    def forward_train(self,\n                      img,\n                      img_metas,\n                      ref_img_metas,\n                      gt_masks,\n                      gt_labels,\n                      gt_instance_ids=None,\n                      gt_sem_seg=None,\n                      gt_sem_cls=None):\n        \"\"\"Forward function in training stage.\"\"\"\n        assert gt_instance_ids is not None\n        num_imgs = len(img_metas)\n        num_frames = len(ref_img_metas[0])\n        results = self._decode_init_proposals(img, img_metas, ref_img_metas)\n        (proposal_feats, x_feats, mask_preds, cls_scores, seg_preds) = results\n        if self.feat_downsample_stride > 1:\n            scaled_mask_preds = F.interpolate(\n                mask_preds,\n                scale_factor=self.feat_downsample_stride,\n                mode='bilinear',\n                align_corners=False)\n            if seg_preds is not None:\n                scaled_seg_preds = F.interpolate(\n                    seg_preds,\n                    scale_factor=self.feat_downsample_stride,\n                    mode='bilinear',\n                    align_corners=False)\n        else:\n            scaled_mask_preds = mask_preds\n            scaled_seg_preds = seg_preds\n\n        if self.hard_target:\n            gt_masks = [x.bool().float() for x in gt_masks]\n        else:\n 
           gt_masks = gt_masks\n\n        sampling_results = []\n        if cls_scores is None:\n            detached_cls_scores = [None] * num_imgs\n        else:\n            detached_cls_scores = cls_scores.detach()\n\n        scaled_mask_preds = scaled_mask_preds.reshape((num_imgs, num_frames, *scaled_mask_preds.size()[1:]))\n\n        num_cls = scaled_seg_preds.size(1)\n        _h, _w = scaled_mask_preds.size()[-2:]\n        scaled_seg_preds = scaled_seg_preds.reshape((num_imgs, num_frames, *scaled_seg_preds.size()[1:]))\n        scaled_seg_preds = torch.einsum('nfshw->nsfhw', scaled_seg_preds).reshape((num_imgs, num_cls, num_frames * _h, _w))\n\n        pred_masks_concat = []\n        for i in range(num_imgs):\n            assign_result, gt_masks_match = self.assigner.assign(scaled_mask_preds[i].detach(),\n                                                 detached_cls_scores[i],\n                                                 gt_masks[i], gt_labels[i],\n                                                 gt_instance_ids[i])\n            num_bboxes = scaled_mask_preds.size(2)\n            h, w = scaled_mask_preds.shape[-2:]\n            pred_masks_match = torch.einsum('fqhw->qfhw', scaled_mask_preds[i]).reshape((num_bboxes, -1, w))\n            sampling_result = self.sampler.sample(assign_result,\n                                                  pred_masks_match,\n                                                  gt_masks_match)\n            sampling_results.append(sampling_result)\n            pred_masks_concat.append(pred_masks_match)\n        pred_masks_concat = torch.stack(pred_masks_concat)\n\n        mask_targets = self.get_targets(\n            sampling_results,\n            self.train_cfg,\n            True,\n            gt_sem_seg=gt_sem_seg,\n            gt_sem_cls=gt_sem_cls)\n\n        losses = self.loss(pred_masks_concat, cls_scores, scaled_seg_preds, None, *mask_targets)\n\n        if self.cat_stuff_mask and self.training:\n            mask_preds 
= torch.cat([mask_preds, seg_preds[:, self.num_thing_classes:]], dim=1)\n            stuff_kernels = self.conv_seg.weight[self.num_thing_classes:].clone()\n            stuff_kernels = stuff_kernels[None].expand(num_imgs * num_frames, *stuff_kernels.size())\n            proposal_feats = torch.cat([proposal_feats, stuff_kernels], dim=1)\n\n        return losses, proposal_feats, x_feats, mask_preds, cls_scores\n\n    def loss(self,\n             mask_pred,\n             cls_scores,\n             seg_preds,\n             proposal_feats,\n             labels,\n             label_weights,\n             mask_targets,\n             mask_weights,\n             seg_targets,\n             reduction_override=None,\n             **kwargs):\n        losses = dict()\n        bg_class_ind = self.num_classes\n        # note: as in Sparse R-CNN, num_gt == num_pos\n        pos_inds = (labels >= 0) & (labels < bg_class_ind)\n        num_preds = mask_pred.shape[0] * mask_pred.shape[1]\n\n        if cls_scores is not None:\n            num_pos = pos_inds.sum().float()\n            avg_factor = reduce_mean(num_pos)\n            assert mask_pred.shape[0] == cls_scores.shape[0]\n            assert mask_pred.shape[1] == cls_scores.shape[1]\n            losses['loss_rpn_cls'] = self.loss_cls(\n                cls_scores.view(num_preds, -1),\n                labels,\n                label_weights,\n                avg_factor=avg_factor,\n                reduction_override=reduction_override)\n            losses['rpn_pos_acc'] = accuracy(\n                cls_scores.view(num_preds, -1)[pos_inds], labels[pos_inds])\n\n        bool_pos_inds = pos_inds.type(torch.bool)\n        # 0~self.num_classes-1 are FG, self.num_classes is BG\n        # mask and dice losses are computed for FG (positive) proposals only.\n        H, W = mask_pred.shape[-2:]\n        if pos_inds.any():\n            pos_mask_pred = mask_pred.reshape(num_preds, H, W)[bool_pos_inds]\n            pos_mask_targets = mask_targets[bool_pos_inds]\n       
     losses['loss_rpn_mask'] = self.loss_mask(pos_mask_pred,\n                                                     pos_mask_targets)\n            losses['loss_rpn_dice'] = self.loss_dice(pos_mask_pred,\n                                                     pos_mask_targets)\n\n            if self.loss_rank is not None:\n                batch_size = mask_pred.size(0)\n                rank_target = mask_targets.new_full((batch_size, H, W),\n                                                    self.ignore_label,\n                                                    dtype=torch.long)\n                rank_inds = pos_inds.view(batch_size,\n                                          -1).nonzero(as_tuple=False)\n                batch_mask_targets = mask_targets.view(batch_size, -1, H,\n                                                       W).bool()\n                for i in range(batch_size):\n                    curr_inds = (rank_inds[:, 0] == i)\n                    curr_rank = rank_inds[:, 1][curr_inds]\n                    for j in curr_rank:\n                        rank_target[i][batch_mask_targets[i][j]] = j\n                losses['loss_rpn_rank'] = self.loss_rank(\n                    mask_pred, rank_target, ignore_index=self.ignore_label)\n\n        else:\n            losses['loss_rpn_mask'] = mask_pred.sum() * 0\n            losses['loss_rpn_dice'] = mask_pred.sum() * 0\n            if self.loss_rank is not None:\n                losses['loss_rpn_rank'] = mask_pred.sum() * 0\n\n        if seg_preds is not None:\n            if self.loss_seg.use_sigmoid:\n                cls_channel = seg_preds.shape[1]\n                flatten_seg = seg_preds.view(\n                    -1, cls_channel,\n                    H * W).permute(0, 2, 1).reshape(-1, cls_channel)\n                flatten_seg_target = seg_targets.view(-1)\n                num_dense_pos = (flatten_seg_target >= 0) & (\n                    flatten_seg_target < bg_class_ind)\n                num_dense_pos = 
num_dense_pos.sum().float().clamp(min=1.0)\n                losses['loss_rpn_seg'] = self.loss_seg(\n                    flatten_seg,\n                    flatten_seg_target,\n                    avg_factor=num_dense_pos)\n            else:\n                cls_channel = seg_preds.shape[1]\n                flatten_seg = seg_preds.view(-1, cls_channel, H * W).permute(\n                    0, 2, 1).reshape(-1, cls_channel)\n                flatten_seg_target = seg_targets.view(-1)\n                losses['loss_rpn_seg'] = self.loss_seg(flatten_seg,\n                                                       flatten_seg_target)\n\n        return losses\n\n    def _get_target_single(self, pos_inds, neg_inds, pos_mask, neg_mask,\n                           pos_gt_mask, pos_gt_labels, gt_sem_seg, gt_sem_cls,\n                           cfg):\n        num_pos = pos_mask.size(0)\n        num_neg = neg_mask.size(0)\n        num_samples = num_pos + num_neg\n        H, W = pos_mask.shape[-2:]\n        # original implementation uses new_zeros since BG are set to be 0\n        # now use empty & fill because BG cat_id = num_classes,\n        # FG cat_id = [0, num_classes-1]\n        labels = pos_mask.new_full((num_samples, ),\n                                   self.num_classes,\n                                   dtype=torch.long)\n        label_weights = pos_mask.new_zeros(num_samples)\n        mask_targets = pos_mask.new_zeros(num_samples, H, W)\n        mask_weights = pos_mask.new_zeros(num_samples, H, W)\n        seg_targets = pos_mask.new_full((H, W),\n                                        self.num_classes,\n                                        dtype=torch.long)\n\n        if gt_sem_cls is not None and gt_sem_seg is not None:\n            gt_sem_seg = gt_sem_seg.bool()\n            for sem_mask, sem_cls in zip(gt_sem_seg, gt_sem_cls):\n                seg_targets[sem_mask] = sem_cls.long()\n\n        if num_pos > 0:\n            labels[pos_inds] = pos_gt_labels\n          
  pos_weight = 1.0 if cfg.pos_weight <= 0 else cfg.pos_weight\n            label_weights[pos_inds] = pos_weight\n            mask_targets[pos_inds, ...] = pos_gt_mask\n            mask_weights[pos_inds, ...] = 1\n            for i in range(num_pos):\n                seg_targets[pos_gt_mask[i].bool()] = pos_gt_labels[i]\n\n        if num_neg > 0:\n            label_weights[neg_inds] = 1.0\n\n        return labels, label_weights, mask_targets, mask_weights, seg_targets\n\n    def get_targets(self,\n                    sampling_results,\n                    rpn_train_cfg,\n                    concat=True,\n                    gt_sem_seg=None,\n                    gt_sem_cls=None):\n        num_imgs = len(sampling_results)\n        pos_inds_list = [res.pos_inds for res in sampling_results]\n        neg_inds_list = [res.neg_inds for res in sampling_results]\n        pos_mask_list = [res.pos_masks for res in sampling_results]\n        neg_mask_list = [res.neg_masks for res in sampling_results]\n        pos_gt_mask_list = [res.pos_gt_masks for res in sampling_results]\n        pos_gt_labels_list = [res.pos_gt_labels for res in sampling_results]\n        if gt_sem_seg is None:\n            gt_sem_seg = [None] * num_imgs\n            gt_sem_cls = [None] * num_imgs\n        results = multi_apply(\n            self._get_target_single,\n            pos_inds_list,\n            neg_inds_list,\n            pos_mask_list,\n            neg_mask_list,\n            pos_gt_mask_list,\n            pos_gt_labels_list,\n            gt_sem_seg,\n            gt_sem_cls,\n            cfg=rpn_train_cfg)\n        (labels, label_weights, mask_targets, mask_weights,\n         seg_targets) = results\n        if concat:\n            labels = torch.cat(labels, 0)\n            label_weights = torch.cat(label_weights, 0)\n            mask_targets = torch.cat(mask_targets, 0)\n            mask_weights = torch.cat(mask_weights, 0)\n            seg_targets = torch.stack(seg_targets, 0)\n        return 
labels, label_weights, mask_targets, mask_weights, seg_targets\n\n    def simple_test_rpn(self, img, img_metas, ref_img_metas):\n        \"\"\"Forward function in testing stage.\"\"\"\n        return self._decode_init_proposals(img, img_metas, ref_img_metas)\n\n    def forward_dummy(self, img, img_metas, ref_img_metas):\n        \"\"\"Dummy forward function.\n\n        Used in FLOPs calculation.\n        \"\"\"\n        return self._decode_init_proposals(img, img_metas, ref_img_metas)\n"
  },
  {
    "path": "knet_vis/tracker/kernel_frame_iter_head.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom mmcv.cnn import build_norm_layer\nfrom mmcv.cnn.bricks.transformer import MultiheadAttention, FFN\n\nfrom mmdet.core import build_assigner, build_sampler\nfrom mmdet.models.builder import HEADS, build_head\nfrom mmdet.models.roi_heads import BaseRoIHead\n\nfrom mmdet.utils import get_root_logger\n\n@HEADS.register_module()\nclass KernelFrameIterHeadVideo(BaseRoIHead):\n    def __init__(self,\n                 mask_head=None,\n                 with_mask_init=False,\n                 num_stages=3,\n                 stage_loss_weights=(1, 1, 1),\n                 proposal_feature_channel=256,\n                 assign_stages=5,\n                 num_proposals=100,\n                 num_thing_classes=80,\n                 num_stuff_classes=53,\n                 query_merge_method='mean',\n                 train_cfg=None,\n                 test_cfg=None,\n                 pretrained=None,\n                 init_cfg=None,\n                 **kwargs):\n        assert len(stage_loss_weights) == num_stages\n        self.num_stages = num_stages\n        self.stage_loss_weights = stage_loss_weights\n        self.assign_stages = assign_stages\n        self.num_proposals = num_proposals\n        self.num_thing_classes = num_thing_classes\n        self.num_stuff_classes = num_stuff_classes\n        self.query_merge_method = query_merge_method\n        self.proposal_feature_channel = proposal_feature_channel\n        super().__init__(\n            mask_head=mask_head,\n            train_cfg=train_cfg,\n            test_cfg=test_cfg,\n            init_cfg=init_cfg,\n            **kwargs\n        )\n        if self.query_merge_method == 'attention':\n            self.init_query = nn.Embedding(self.num_proposals, self.proposal_feature_channel)\n            _num_head = 8\n            _drop_out = 0.\n            self.query_merge_attn = MultiheadAttention(self.proposal_feature_channel, _num_head, 
_drop_out, batch_first=True)\n            self.query_merge_norm = build_norm_layer(dict(type='LN'), self.proposal_feature_channel)[1]\n            self.query_merge_ffn = FFN(\n                self.proposal_feature_channel,\n                self.proposal_feature_channel * 8,\n                num_ffn_fcs=2,\n                act_cfg=dict(type='ReLU', inplace=True),\n                ffn_drop=0.)\n            self.query_merge_ffn_norm = build_norm_layer(dict(type='LN'), self.proposal_feature_channel)[1]\n        elif self.query_merge_method == 'attention_pos':\n            self.init_query = nn.Embedding(self.num_proposals, self.proposal_feature_channel)\n            self.query_pos = nn.Embedding(self.num_proposals, self.proposal_feature_channel)\n            _num_head = 8\n            _drop_out = 0.\n            self.query_merge_attn = MultiheadAttention(self.proposal_feature_channel, _num_head, _drop_out,\n                                                       batch_first=True)\n            self.query_merge_norm = build_norm_layer(dict(type='LN'), self.proposal_feature_channel)[1]\n            self.query_merge_ffn = FFN(\n                self.proposal_feature_channel,\n                self.proposal_feature_channel * 8,\n                num_ffn_fcs=2,\n                act_cfg=dict(type='ReLU', inplace=True),\n                ffn_drop=0.)\n            self.query_merge_ffn_norm = build_norm_layer(dict(type='LN'), self.proposal_feature_channel)[1]\n\n        self.with_mask_init = with_mask_init\n        if self.with_mask_init:\n            self.fc_mask = nn.Linear(proposal_feature_channel, proposal_feature_channel)\n\n        self.logger = get_root_logger()\n\n    def init_mask_head(self, bbox_roi_extractor=None, mask_head=None):\n        assert bbox_roi_extractor is None\n        self.mask_head = nn.ModuleList()\n        if not isinstance(mask_head, list):\n            mask_head = [mask_head for _ in range(self.num_stages)]\n        assert len(mask_head) == 
self.num_stages\n        for idx, head in enumerate(mask_head):\n            head.update(with_cls=(idx < self.assign_stages))\n            self.mask_head.append(build_head(head))\n\n    def init_assigner_sampler(self):\n        \"\"\"Initialize assigner and sampler for each stage.\"\"\"\n        self.mask_assigner = []\n        self.mask_sampler = []\n        if self.train_cfg is not None:\n            for i in range(self.num_stages):\n                self.mask_assigner.append(\n                    build_assigner(self.train_cfg.assigner))\n                self.current_stage = i\n                self.mask_sampler.append(\n                    build_sampler(self.train_cfg.sampler, context=self))\n\n    def init_bbox_head(self, mask_roi_extractor, mask_head):\n        \"\"\"Initialize box head and box RoI extractor (not supported by this head).\n\n        Args:\n            mask_roi_extractor (dict): Config of the box RoI extractor.\n            mask_head (dict): Config of the box head.\n        \"\"\"\n        raise NotImplementedError\n\n    def _mask_forward(self, stage, x, object_feats, mask_preds):\n        mask_head = self.mask_head[stage]\n        cls_score, mask_preds, object_feats = mask_head(\n            x, object_feats, mask_preds, img_metas=None,\n            pos=self.query_pos.weight if self.query_merge_method == 'attention_pos' else None)\n        if mask_head.mask_upsample_stride > 1 and (stage == self.num_stages - 1 or self.training):\n            scaled_mask_preds = [\n                F.interpolate(\n                    mask_preds[i],\n                    scale_factor=mask_head.mask_upsample_stride,\n                    align_corners=False,\n                    mode='bilinear'\n                ) for i in range(mask_preds.size(0))\n            ]\n            scaled_mask_preds = torch.stack(scaled_mask_preds)\n        else:\n            scaled_mask_preds = mask_preds\n\n        mask_results = dict(\n            cls_score=cls_score,\n            mask_preds=mask_preds,\n            
scaled_mask_preds=scaled_mask_preds,\n            object_feats=object_feats\n        )\n        return mask_results\n\n    def _query_fusion(self, obj_feats, num_imgs, num_frames):\n        if self.query_merge_method == 'mean':\n            object_feats = obj_feats.mean(1)\n        elif self.query_merge_method == 'attention':\n            assert obj_feats.size()[-2:] == (1, 1), \"Only kernel size 1 is supported\"\n            obj_feats = obj_feats.reshape((num_imgs, num_frames * self.num_proposals, self.proposal_feature_channel))\n            init_query = self.init_query.weight.expand(num_imgs, *self.init_query.weight.size())\n            obj_feats = self.query_merge_attn(query=init_query, key=obj_feats, value=obj_feats)\n            obj_feats = self.query_merge_norm(obj_feats)\n            object_feats = self.query_merge_ffn_norm(self.query_merge_ffn(obj_feats))\n            object_feats = object_feats[..., None, None]\n        elif self.query_merge_method == 'attention_pos':\n            assert obj_feats.size()[-2:] == (1, 1), \"Only kernel size 1 is supported\"\n            obj_feats = obj_feats.reshape((num_imgs, num_frames * self.num_proposals, self.proposal_feature_channel))\n            init_query = self.init_query.weight.expand(num_imgs, *self.init_query.weight.size())\n            query_pos = self.query_pos.weight.repeat(num_imgs, 1, 1)\n            key_pos = query_pos.repeat(1, num_frames, 1)\n            obj_feats = self.query_merge_attn(query=init_query, key=obj_feats, value=obj_feats,\n                                              query_pos=query_pos, key_pos=key_pos)\n            obj_feats = self.query_merge_norm(obj_feats)\n            object_feats = self.query_merge_ffn_norm(self.query_merge_ffn(obj_feats))\n            object_feats = object_feats[..., None, None]\n\n        return object_feats\n\n    def _mask_init(self, object_feats, x_feats, num_imgs):\n        assert object_feats.size()[-2:] == (1, 1), \"Only kernel size 1 is supported\"\n        
object_feats = object_feats.flatten(-3, -1)  # BNCKK -> BNC\n        mask_feat = self.fc_mask(object_feats)[..., None, None]\n        mask_preds = []\n        for i in range(num_imgs):\n            mask_preds.append(\n                F.conv2d(\n                    x_feats[i],\n                    mask_feat[i],\n                    padding=0)\n            )\n\n        mask_preds = torch.stack(mask_preds, dim=0)\n\n        return mask_preds\n\n    def forward_train(self,\n                      x,\n                      ref_img_metas,\n                      cls_scores,\n                      masks,\n                      obj_feats,\n                      ref_gt_masks,\n                      ref_gt_labels,\n                      ref_gt_instance_ids,\n                      **kwargs):\n        num_imgs = len(ref_img_metas)\n        num_frames = len(ref_img_metas[0])\n        if len(obj_feats.size()) == 6:\n            object_feats = self._query_fusion(obj_feats, num_imgs, num_frames)\n        else:\n            object_feats = obj_feats\n\n        all_stage_loss = {}\n        if self.with_mask_init:\n            mask_preds = self._mask_init(object_feats, x, num_imgs)\n            assert self.training\n            if self.mask_head[0].mask_upsample_stride > 1:\n                scaled_mask_preds = [\n                    F.interpolate(\n                        mask_preds[i],\n                        scale_factor=self.mask_head[0].mask_upsample_stride,\n                        align_corners=False,\n                        mode='bilinear'\n                    ) for i in range(mask_preds.size(0))\n                ]\n                scaled_mask_preds = torch.stack(scaled_mask_preds)\n            else:\n                scaled_mask_preds = mask_preds\n            _gt_masks_matches = []\n            _assign_results = []\n            _sampling_results = []\n            _pred_masks_concat = []\n            for i in range(num_imgs):\n                mask_for_assign = 
scaled_mask_preds[i][:self.num_proposals].detach()\n                cls_for_assign = None\n                assign_result, gt_masks_match = self.mask_assigner[0].assign(\n                    mask_for_assign, cls_for_assign, ref_gt_masks[i], ref_gt_labels[i], ref_gt_instance_ids[i])\n                _gt_masks_matches.append(gt_masks_match)\n                _assign_results.append(assign_result)\n                num_bboxes = scaled_mask_preds.size(2)\n                h, w = scaled_mask_preds.shape[-2:]\n                pred_masks_match = torch.einsum('fqhw->qfhw', scaled_mask_preds[i]).reshape((num_bboxes, -1, w))\n                sampling_result = self.mask_sampler[0].sample(\n                    assign_result, pred_masks_match, gt_masks_match)\n                _sampling_results.append(sampling_result)\n                _pred_masks_concat.append(pred_masks_match)\n            pred_masks_concat = torch.stack(_pred_masks_concat)\n            mask_targets = self.mask_head[0].get_targets(\n                _sampling_results,\n                self.train_cfg,\n                True,\n                gt_sem_seg=None,\n                gt_sem_cls=None\n            )\n\n            single_stage_loss = self.mask_head[0].loss(\n                object_feats,\n                None,\n                pred_masks_concat,\n                *mask_targets)\n            for key, value in single_stage_loss.items():\n                all_stage_loss[f'tracker_init_{key}'] = value * self.stage_loss_weights[0]\n        else:\n            mask_preds = masks\n\n\n        assign_results = []\n        for stage in range(self.num_stages):\n            if stage == self.assign_stages:\n                object_feats = object_feats[:, None].repeat(1, num_frames, 1, 1, 1, 1)\n            mask_results = self._mask_forward(stage, x, object_feats, mask_preds)\n            mask_preds = mask_results['mask_preds']\n            scaled_mask_preds = mask_results['scaled_mask_preds']\n            cls_score = 
mask_results['cls_score']\n            object_feats = mask_results['object_feats']\n\n            prev_mask_preds = scaled_mask_preds.detach()\n            prev_cls_score = cls_score.detach() if cls_score is not None else None\n\n            sampling_results = []\n            pred_masks_concat = []\n            if stage < self.assign_stages:\n                assign_results = []\n                gt_masks_matches = []\n            for i in range(num_imgs):\n                if stage < self.assign_stages:\n                    mask_for_assign = prev_mask_preds[i][:, :self.num_proposals]\n                    if prev_cls_score is not None:\n                        cls_for_assign = prev_cls_score[i][:self.num_proposals, :self.num_thing_classes]\n                    else:\n                        cls_for_assign = None\n                    assign_result, gt_masks_match = self.mask_assigner[stage].assign(\n                        mask_for_assign, cls_for_assign, ref_gt_masks[i], ref_gt_labels[i], ref_gt_instance_ids[i])\n                    gt_masks_matches.append(gt_masks_match)\n                    assign_results.append(assign_result)\n                num_bboxes = scaled_mask_preds.size(2)\n                h, w = scaled_mask_preds.shape[-2:]\n                pred_masks_match = torch.einsum('fqhw->qfhw', scaled_mask_preds[i]).reshape((num_bboxes, -1, w))\n                sampling_result = self.mask_sampler[stage].sample(\n                    assign_results[i], pred_masks_match, gt_masks_matches[i])\n                sampling_results.append(sampling_result)\n                pred_masks_concat.append(pred_masks_match)\n            pred_masks_concat = torch.stack(pred_masks_concat)\n            mask_targets = self.mask_head[stage].get_targets(\n                sampling_results,\n                self.train_cfg,\n                True,\n                gt_sem_seg=None,\n                gt_sem_cls=None\n            )\n\n            single_stage_loss = self.mask_head[stage].loss(\n    
            object_feats,\n                cls_score,\n                pred_masks_concat,\n                *mask_targets)\n            for key, value in single_stage_loss.items():\n                all_stage_loss[f'tracker_s{stage}_{key}'] = value * self.stage_loss_weights[stage]\n\n        features = {\n            \"obj_feats\": object_feats,\n            \"x_feats\": x,\n            \"cls_scores\": cls_score,\n            \"masks\": mask_preds,\n        }\n        return all_stage_loss, features\n\n    def simple_test(self,\n                    x,\n                    img_metas,\n                    ref_img_metas,\n                    cls_scores,\n                    masks,\n                    obj_feats,\n                    **kwargs):\n        num_imgs = len(ref_img_metas)\n        num_frames = len(ref_img_metas[0])\n\n        if len(obj_feats.size()) == 6:\n            object_feats = self._query_fusion(obj_feats, num_imgs, num_frames)\n        else:\n            object_feats = obj_feats\n\n        if self.with_mask_init:\n            mask_preds = self._mask_init(object_feats, x, num_imgs)\n        else:\n            mask_preds = masks\n\n        cls_score = None\n        for stage in range(self.num_stages):\n            if stage == self.assign_stages:\n                object_feats = object_feats[:, None].repeat(1, num_frames, 1, 1, 1, 1)\n            mask_results = self._mask_forward(stage, x, object_feats, mask_preds)\n            mask_preds = mask_results['mask_preds']\n            scaled_mask_preds = mask_results['scaled_mask_preds']\n            cls_score = mask_results['cls_score'] if mask_results['cls_score'] is not None else cls_score\n            object_feats = mask_results['object_feats']\n\n        num_classes = self.mask_head[-1].num_classes\n        results = []\n        if self.mask_head[-1].loss_cls.use_sigmoid:\n            cls_score = cls_score.sigmoid()\n        else:\n            cls_score = cls_score.softmax(-1)[..., :-1]\n\n        for 
img_id in range(num_imgs):\n            result = []\n            cls_score_per_img = cls_score[img_id]\n            # topk runs over the flattened (query, class) scores, so a single query can yield several results with different labels\n            scores_per_img, topk_indices = cls_score_per_img.flatten(0, 1).topk(\n                self.test_cfg.max_per_img, sorted=True)\n            mask_indices = topk_indices // num_classes\n            # Use the following when torch >= 1.9.0\n            # mask_indices = torch.div(topk_indices, num_classes, rounding_mode='floor')\n            labels_per_img = topk_indices % num_classes\n            for frame_id in range(num_frames):\n                masks_per_img = scaled_mask_preds[img_id][frame_id][mask_indices]\n                single_result = self.mask_head[-1].get_seg_masks_tracking(\n                    masks_per_img, labels_per_img, scores_per_img,\n                    torch.arange(self.test_cfg.max_per_img),\n                    self.test_cfg, img_metas[img_id])\n                result.append(single_result)\n            results.append(result)\n        features = {\n            \"obj_feats\": object_feats,\n            \"x_feats\": x,\n            \"cls_scores\": cls_score,\n            \"masks\": mask_preds,\n        }\n        return results, features\n\n    def init_weights(self):\n        if self.init_cfg is not None and self.init_cfg['type'] == 'Pretrained' and self.init_cfg['prefix'] is not None:\n            from mmcv.cnn import initialize\n            self.logger.info(\"Customized loading the tracker.\")\n            initialize(self, self.init_cfg)\n        else:\n            super().init_weights()\n"
  },
  {
    "path": "knet_vis/tracker/kernel_head.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom mmcv.cnn import (ConvModule, bias_init_with_prob, normal_init)\nfrom mmdet.core import build_assigner, build_sampler, multi_apply, reduce_mean\nfrom mmdet.models.builder import HEADS, build_loss, build_neck\nfrom mmdet.models.losses import accuracy\nfrom mmdet.utils import get_root_logger\n\n\n@HEADS.register_module()\nclass ConvKernelHeadVideo(nn.Module):\n    def __init__(self,\n                 num_proposals=100,\n                 in_channels=256,\n                 out_channels=256,\n                 num_heads=8,\n                 num_cls_fcs=1,\n                 num_seg_convs=1,\n                 num_loc_convs=1,\n                 att_dropout=False,\n                 localization_fpn=None,\n                 conv_kernel_size=1,\n                 norm_cfg=dict(type='GN', num_groups=32),\n                 semantic_fpn=True,\n                 train_cfg=None,\n                 num_classes=80,\n                 xavier_init_kernel=False,\n                 kernel_init_std=0.01,\n                 use_binary=False,\n                 proposal_feats_with_obj=False,\n                 loss_mask=None,\n                 loss_seg=None,\n                 loss_cls=None,\n                 loss_dice=None,\n                 loss_rank=None,\n                 feat_downsample_stride=1,\n                 feat_refine_stride=1,\n                 feat_refine=True,\n                 with_embed=False,\n                 feat_embed_only=False,\n                 conv_normal_init=False,\n                 mask_out_stride=4,\n                 hard_target=False,\n                 num_thing_classes=80,\n                 num_stuff_classes=53,\n                 mask_assign_stride=4,\n                 ignore_label=255,\n                 thing_label_in_seg=0,\n                 cat_stuff_mask=False,\n                 **kwargs):\n        super().__init__()\n        self.num_proposals = num_proposals\n        
self.num_cls_fcs = num_cls_fcs\n        self.train_cfg = train_cfg\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.num_classes = num_classes\n        self.proposal_feats_with_obj = proposal_feats_with_obj\n        self.sampling = False\n        self.localization_fpn = build_neck(localization_fpn)\n        self.semantic_fpn = semantic_fpn\n        self.norm_cfg = norm_cfg\n        self.num_heads = num_heads\n        self.att_dropout = att_dropout\n        self.mask_out_stride = mask_out_stride\n        self.hard_target = hard_target\n        self.conv_kernel_size = conv_kernel_size\n        self.xavier_init_kernel = xavier_init_kernel\n        self.kernel_init_std = kernel_init_std\n        self.feat_downsample_stride = feat_downsample_stride\n        self.feat_refine_stride = feat_refine_stride\n        self.conv_normal_init = conv_normal_init\n        self.feat_refine = feat_refine\n        self.with_embed = with_embed\n        self.feat_embed_only = feat_embed_only\n        self.num_loc_convs = num_loc_convs\n        self.num_seg_convs = num_seg_convs\n        self.use_binary = use_binary\n        self.num_thing_classes = num_thing_classes\n        self.num_stuff_classes = num_stuff_classes\n        self.mask_assign_stride = mask_assign_stride\n        self.ignore_label = ignore_label\n        self.thing_label_in_seg = thing_label_in_seg\n        self.cat_stuff_mask = cat_stuff_mask\n\n        if loss_mask is not None:\n            self.loss_mask = build_loss(loss_mask)\n        else:\n            self.loss_mask = loss_mask\n\n        if loss_dice is not None:\n            self.loss_dice = build_loss(loss_dice)\n        else:\n            self.loss_dice = loss_dice\n\n        if loss_seg is not None:\n            self.loss_seg = build_loss(loss_seg)\n        else:\n            self.loss_seg = loss_seg\n        if loss_cls is not None:\n            self.loss_cls = build_loss(loss_cls)\n        else:\n            
self.loss_cls = loss_cls\n\n        if loss_rank is not None:\n            self.loss_rank = build_loss(loss_rank)\n        else:\n            self.loss_rank = loss_rank\n\n        if self.train_cfg:\n            self.assigner = build_assigner(self.train_cfg.assigner)\n            # use PseudoSampler when sampling is False\n            if self.sampling and hasattr(self.train_cfg, 'sampler'):\n                sampler_cfg = self.train_cfg.sampler\n            else:\n                sampler_cfg = dict(type='MaskPseudoSampler')\n            self.sampler = build_sampler(sampler_cfg, context=self)\n        self._init_layers()\n\n    def _init_layers(self):\n        \"\"\"Initialize the proposal kernels, the semantic seg head and the feature convs.\"\"\"\n        self.init_kernels = nn.Conv2d(\n            self.out_channels,\n            self.num_proposals,\n            self.conv_kernel_size,\n            padding=int(self.conv_kernel_size // 2),\n            bias=False)\n\n        if self.semantic_fpn:\n            if self.loss_seg.use_sigmoid:\n                self.conv_seg = nn.Conv2d(self.out_channels, self.num_classes, 1)\n            else:\n                self.conv_seg = nn.Conv2d(self.out_channels, self.num_classes + 1, 1)\n\n        if self.feat_downsample_stride > 1 and self.feat_refine:\n            self.ins_downsample = ConvModule(\n                self.in_channels,\n                self.out_channels,\n                3,\n                stride=self.feat_refine_stride,\n                padding=1,\n                norm_cfg=self.norm_cfg)\n            self.seg_downsample = ConvModule(\n                self.in_channels,\n                self.out_channels,\n                3,\n                stride=self.feat_refine_stride,\n                padding=1,\n                norm_cfg=self.norm_cfg)\n\n        self.loc_convs = nn.ModuleList()\n        for i in range(self.num_loc_convs):\n            
self.loc_convs.append(\n                ConvModule(\n                    self.in_channels,\n                    self.out_channels,\n                    1,\n                    norm_cfg=self.norm_cfg))\n\n        self.seg_convs = nn.ModuleList()\n        for i in range(self.num_seg_convs):\n            self.seg_convs.append(\n                ConvModule(\n                    self.in_channels,\n                    self.out_channels,\n                    1,\n                    norm_cfg=self.norm_cfg))\n\n    def init_weights(self):\n        self.localization_fpn.init_weights()\n\n        if self.feat_downsample_stride > 1 and self.conv_normal_init:\n            logger = get_root_logger()\n            logger.info('Initialize convs in KPN head by normal std 0.01')\n            for conv in [self.loc_convs, self.seg_convs]:\n                for m in conv.modules():\n                    if isinstance(m, nn.Conv2d):\n                        normal_init(m, std=0.01)\n\n        if self.semantic_fpn:\n            bias_seg = bias_init_with_prob(0.01)\n            if self.loss_seg.use_sigmoid:\n                normal_init(self.conv_seg, std=0.01, bias=bias_seg)\n            else:\n                normal_init(self.conv_seg, mean=0, std=0.01)\n        if self.xavier_init_kernel:\n            logger = get_root_logger()\n            logger.info('Initialize kernels by xavier uniform')\n            nn.init.xavier_uniform_(self.init_kernels.weight)\n        else:\n            logger = get_root_logger()\n            logger.info(\n                f'Initialize kernels by normal std: {self.kernel_init_std}')\n            normal_init(self.init_kernels, mean=0, std=self.kernel_init_std)\n\n    def _decode_init_proposals(self, img, img_metas, ref_img_metas):\n        num_imgs = len(img_metas)\n        num_frames = len(ref_img_metas[0])\n\n        if self.localization_fpn.__class__.__name__.endswith('3D'):\n            localization_feats = self.localization_fpn(img, num_imgs, num_frames)\n     
   else:\n            localization_feats = self.localization_fpn(img)\n        if isinstance(localization_feats, list):\n            loc_feats = localization_feats[0]\n        else:\n            loc_feats = localization_feats\n        for conv in self.loc_convs:\n            loc_feats = conv(loc_feats)\n        if self.feat_downsample_stride > 1 and self.feat_refine:\n            loc_feats = self.ins_downsample(loc_feats)\n        mask_preds = self.init_kernels(loc_feats)\n\n        if self.semantic_fpn:\n            if isinstance(localization_feats, list):\n                semantic_feats = localization_feats[1]\n            else:\n                semantic_feats = localization_feats\n            for conv in self.seg_convs:\n                semantic_feats = conv(semantic_feats)\n            if self.feat_downsample_stride > 1 and self.feat_refine:\n                semantic_feats = self.seg_downsample(semantic_feats)\n        else:\n            semantic_feats = None\n\n        if semantic_feats is not None:\n            seg_preds = self.conv_seg(semantic_feats)\n        else:\n            seg_preds = None\n\n        proposal_feats = self.init_kernels.weight.clone()\n        proposal_feats = proposal_feats[None].expand(num_imgs * num_frames, *proposal_feats.size())\n\n        if semantic_feats is not None:\n            x_feats = semantic_feats + loc_feats\n        else:\n            x_feats = loc_feats\n\n        if self.proposal_feats_with_obj:\n            sigmoid_masks = mask_preds.sigmoid()\n            nonzero_inds = sigmoid_masks > 0.5\n            if self.use_binary:\n                sigmoid_masks = nonzero_inds.float()\n            else:\n                sigmoid_masks = nonzero_inds.float() * sigmoid_masks\n            obj_feats = torch.einsum('bnhw,bchw->bnc', sigmoid_masks, x_feats)\n\n        cls_scores = None\n\n        if self.proposal_feats_with_obj:\n            proposal_feats = proposal_feats + obj_feats.view(\n                num_imgs * num_frames, 
self.num_proposals, self.out_channels, 1, 1)\n\n        if self.cat_stuff_mask and not self.training:\n            mask_preds = torch.cat(\n                [mask_preds, seg_preds[:, self.num_thing_classes:]], dim=1)\n            stuff_kernels = self.conv_seg.weight[self.num_thing_classes:].clone()\n            stuff_kernels = stuff_kernels[None].expand(num_imgs * num_frames, *stuff_kernels.size())\n            proposal_feats = torch.cat([proposal_feats, stuff_kernels], dim=1)\n\n        return proposal_feats, x_feats, mask_preds, cls_scores, seg_preds\n\n    def forward_train(self,\n                      img,\n                      img_metas,\n                      ref_img_metas,\n                      gt_masks,\n                      gt_labels,\n                      gt_instance_ids=None,\n                      gt_sem_seg=None,\n                      gt_sem_cls=None):\n        \"\"\"Forward function in training stage.\"\"\"\n        num_imgs = len(img_metas)\n        num_frames = len(ref_img_metas[0])\n        results = self._decode_init_proposals(img, img_metas, ref_img_metas)\n        (proposal_feats, x_feats, mask_preds, cls_scores, seg_preds) = results\n        if self.feat_downsample_stride > 1:\n            scaled_mask_preds = F.interpolate(\n                mask_preds,\n                scale_factor=self.feat_downsample_stride,\n                mode='bilinear',\n                align_corners=False)\n            if seg_preds is not None:\n                scaled_seg_preds = F.interpolate(\n                    seg_preds,\n                    scale_factor=self.feat_downsample_stride,\n                    mode='bilinear',\n                    align_corners=False)\n            else:\n                # keep the name bound (semantic_fpn=False); self.loss() then skips the seg loss\n                scaled_seg_preds = None\n        else:\n            scaled_mask_preds = mask_preds\n            scaled_seg_preds = seg_preds\n\n        if self.hard_target:\n            gt_masks = [x.bool().float() for x in gt_masks]\n        else:\n            gt_masks = gt_masks\n\n        
sampling_results = []\n        if cls_scores is None:\n            detached_cls_scores = [[None] * num_frames] * num_imgs\n        else:\n            detached_cls_scores = cls_scores.detach()\n\n        for i in range(num_imgs):\n            for j in range(num_frames):\n                assign_result = self.assigner.assign(scaled_mask_preds[i * num_frames + j].detach(),\n                                                     detached_cls_scores[i][j],\n                                                     gt_masks[i][j], gt_labels[i][:,1][gt_labels[i][:,0]==j],\n                                                     ref_img_metas[i][j])\n                sampling_result = self.sampler.sample(assign_result,\n                                                      scaled_mask_preds[i * num_frames + j],\n                                                      gt_masks[i][j])\n                sampling_results.append(sampling_result)\n\n        mask_targets = self.get_targets(\n            sampling_results,\n            self.train_cfg,\n            True,\n            gt_sem_seg=gt_sem_seg,\n            gt_sem_cls=gt_sem_cls)\n\n        losses = self.loss(scaled_mask_preds, cls_scores, scaled_seg_preds, proposal_feats, *mask_targets)\n\n        if self.cat_stuff_mask and self.training:\n            mask_preds = torch.cat([mask_preds, seg_preds[:, self.num_thing_classes:]], dim=1)\n            stuff_kernels = self.conv_seg.weight[self.num_thing_classes:].clone()\n            stuff_kernels = stuff_kernels[None].expand(num_imgs * num_frames, *stuff_kernels.size())\n            proposal_feats = torch.cat([proposal_feats, stuff_kernels], dim=1)\n\n        return losses, proposal_feats, x_feats, mask_preds, cls_scores\n\n    def loss(self,\n             mask_pred,\n             cls_scores,\n             seg_preds,\n             proposal_feats,\n             labels,\n             label_weights,\n             mask_targets,\n             mask_weights,\n             seg_targets,\n            
 reduction_override=None,\n             **kwargs):\n        losses = dict()\n        bg_class_ind = self.num_classes\n        # note: in Sparse R-CNN num_gt == num_pos\n        pos_inds = (labels >= 0) & (labels < bg_class_ind)\n        num_preds = mask_pred.shape[0] * mask_pred.shape[1]\n\n        if cls_scores is not None:\n            num_pos = pos_inds.sum().float()\n            avg_factor = reduce_mean(num_pos)\n            assert mask_pred.shape[0] == cls_scores.shape[0]\n            assert mask_pred.shape[1] == cls_scores.shape[1]\n            losses['loss_rpn_cls'] = self.loss_cls(\n                cls_scores.view(num_preds, -1),\n                labels,\n                label_weights,\n                avg_factor=avg_factor,\n                reduction_override=reduction_override)\n            losses['rpn_pos_acc'] = accuracy(\n                cls_scores.view(num_preds, -1)[pos_inds], labels[pos_inds])\n\n        bool_pos_inds = pos_inds.type(torch.bool)\n        # 0~self.num_classes-1 are FG, self.num_classes is BG\n        # mask and dice losses are computed on FG (positive) predictions only\n        H, W = mask_pred.shape[-2:]\n        if pos_inds.any():\n            pos_mask_pred = mask_pred.reshape(num_preds, H, W)[bool_pos_inds]\n            pos_mask_targets = mask_targets[bool_pos_inds]\n            losses['loss_rpn_mask'] = self.loss_mask(\n                pos_mask_pred, pos_mask_targets)\n            losses['loss_rpn_dice'] = self.loss_dice(\n                pos_mask_pred, pos_mask_targets)\n\n            if self.loss_rank is not None:\n                batch_size = mask_pred.size(0)\n                rank_target = mask_targets.new_full(\n                    (batch_size, H, W), self.ignore_label, dtype=torch.long)\n                rank_inds = pos_inds.view(batch_size, -1).nonzero(as_tuple=False)\n                batch_mask_targets = mask_targets.view(batch_size, -1, H, W).bool()\n                for i in range(batch_size):\n                    curr_inds = (rank_inds[:, 0] == i)\n                    curr_rank = rank_inds[:, 1][curr_inds]\n                    for j in curr_rank:\n                        rank_target[i][batch_mask_targets[i][j]] = j\n                losses['loss_rpn_rank'] = self.loss_rank(\n                    mask_pred, rank_target, ignore_index=self.ignore_label)\n\n        else:\n            losses['loss_rpn_mask'] = mask_pred.sum() * 0\n            losses['loss_rpn_dice'] = mask_pred.sum() * 0\n            if self.loss_rank is not None:\n                # same key as the positive branch, so the logged loss keys stay consistent\n                losses['loss_rpn_rank'] = mask_pred.sum() * 0\n\n        if seg_preds is not None:\n            if self.loss_seg.use_sigmoid:\n                cls_channel = seg_preds.shape[1]\n                flatten_seg = seg_preds.view(\n                    -1, cls_channel,\n                    H * W).permute(0, 2, 1).reshape(-1, cls_channel)\n                flatten_seg_target = seg_targets.view(-1)\n                num_dense_pos = (flatten_seg_target >= 0) & (\n                    flatten_seg_target < bg_class_ind)\n                num_dense_pos = num_dense_pos.sum().float().clamp(min=1.0)\n                losses['loss_rpn_seg'] = self.loss_seg(\n                    flatten_seg,\n                    flatten_seg_target,\n                    avg_factor=num_dense_pos)\n            else:\n                cls_channel = seg_preds.shape[1]\n                flatten_seg = seg_preds.view(-1, cls_channel, H * W).permute(\n                    0, 2, 1).reshape(-1, cls_channel)\n                flatten_seg_target = seg_targets.view(-1)\n                losses['loss_rpn_seg'] = self.loss_seg(\n                    flatten_seg, flatten_seg_target)\n\n        return losses\n\n    def 
_get_target_single(self, pos_inds, neg_inds, pos_mask, neg_mask,\n                           pos_gt_mask, pos_gt_labels, gt_sem_seg, gt_sem_cls,\n                           cfg):\n        num_pos = pos_mask.size(0)\n        num_neg = neg_mask.size(0)\n        num_samples = num_pos + num_neg\n        H, W = pos_mask.shape[-2:]\n        # original implementation uses new_zeros since BG are set to be 0\n        # now use empty & fill because BG cat_id = num_classes,\n        # FG cat_id = [0, num_classes-1]\n        labels = pos_mask.new_full((num_samples, ),\n                                   self.num_classes,\n                                   dtype=torch.long)\n        label_weights = pos_mask.new_zeros(num_samples)\n        mask_targets = pos_mask.new_zeros(num_samples, H, W)\n        mask_weights = pos_mask.new_zeros(num_samples, H, W)\n        seg_targets = pos_mask.new_full((H, W),\n                                        self.num_classes,\n                                        dtype=torch.long)\n\n        if gt_sem_cls is not None and gt_sem_seg is not None:\n            gt_sem_seg = gt_sem_seg.bool()\n            for sem_mask, sem_cls in zip(gt_sem_seg, gt_sem_cls):\n                seg_targets[sem_mask] = sem_cls.long()\n\n        if num_pos > 0:\n            labels[pos_inds] = pos_gt_labels\n            pos_weight = 1.0 if cfg.pos_weight <= 0 else cfg.pos_weight\n            label_weights[pos_inds] = pos_weight\n            mask_targets[pos_inds, ...] = pos_gt_mask\n            mask_weights[pos_inds, ...] 
= 1\n            for i in range(num_pos):\n                seg_targets[pos_gt_mask[i].bool()] = pos_gt_labels[i]\n\n        if num_neg > 0:\n            label_weights[neg_inds] = 1.0\n\n        return labels, label_weights, mask_targets, mask_weights, seg_targets\n\n    def get_targets(self,\n                    sampling_results,\n                    rpn_train_cfg,\n                    concat=True,\n                    gt_sem_seg=None,\n                    gt_sem_cls=None):\n        num_imgs = len(sampling_results)\n        pos_inds_list = [res.pos_inds for res in sampling_results]\n        neg_inds_list = [res.neg_inds for res in sampling_results]\n        pos_mask_list = [res.pos_masks for res in sampling_results]\n        neg_mask_list = [res.neg_masks for res in sampling_results]\n        pos_gt_mask_list = [res.pos_gt_masks for res in sampling_results]\n        pos_gt_labels_list = [res.pos_gt_labels for res in sampling_results]\n        if gt_sem_seg is None:\n            gt_sem_seg = [None] * num_imgs\n            gt_sem_cls = [None] * num_imgs\n        results = multi_apply(\n            self._get_target_single,\n            pos_inds_list,\n            neg_inds_list,\n            pos_mask_list,\n            neg_mask_list,\n            pos_gt_mask_list,\n            pos_gt_labels_list,\n            gt_sem_seg,\n            gt_sem_cls,\n            cfg=rpn_train_cfg)\n        (labels, label_weights, mask_targets, mask_weights,\n         seg_targets) = results\n        if concat:\n            labels = torch.cat(labels, 0)\n            label_weights = torch.cat(label_weights, 0)\n            mask_targets = torch.cat(mask_targets, 0)\n            mask_weights = torch.cat(mask_weights, 0)\n            seg_targets = torch.stack(seg_targets, 0)\n        return labels, label_weights, mask_targets, mask_weights, seg_targets\n\n    def simple_test_rpn(self, img, img_metas, ref_img_metas):\n        \"\"\"Forward function in testing stage.\"\"\"\n        return 
self._decode_init_proposals(img, img_metas, ref_img_metas)\n\n    def forward_dummy(self, img, img_metas, ref_img_metas):\n        \"\"\"Dummy forward function.\n\n        Used in flops calculation.\n        \"\"\"\n        return self._decode_init_proposals(img, img_metas, ref_img_metas)\n"
  },
  {
    "path": "knet_vis/tracker/kernel_iter_head.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom mmdet.core import build_assigner, build_sampler\nfrom mmdet.datasets.coco_panoptic import INSTANCE_OFFSET\nfrom mmdet.models.builder import HEADS, build_head\nfrom mmdet.models.roi_heads import BaseRoIHead\n\nfrom knet_vis.det.mask_pseudo_sampler import MaskPseudoSampler\n\n\n@HEADS.register_module()\nclass KernelIterHeadVideo(BaseRoIHead):\n    def __init__(self,\n                 num_stages=6,\n                 recursive=False,\n                 assign_stages=5,\n                 stage_loss_weights=(1, 1, 1, 1, 1, 1),\n                 proposal_feature_channel=256,\n                 merge_cls_scores=False,\n                 do_panoptic=False,\n                 post_assign=False,\n                 hard_target=False,\n                 num_proposals=100,\n                 num_thing_classes=80,\n                 num_stuff_classes=53,\n                 mask_assign_stride=4,\n                 thing_label_in_seg=0,\n                 mask_head=dict(\n                     type='KernelUpdateHead',\n                     num_classes=80,\n                     num_fcs=2,\n                     num_heads=8,\n                     num_cls_fcs=1,\n                     num_reg_fcs=3,\n                     feedforward_channels=2048,\n                     hidden_channels=256,\n                     dropout=0.0,\n                     roi_feat_size=7,\n                     ffn_act_cfg=dict(type='ReLU', inplace=True)),\n                 mask_out_stride=4,\n                 train_cfg=None,\n                 test_cfg=None,\n                 **kwargs):\n        assert mask_head is not None\n        assert len(stage_loss_weights) == num_stages\n        self.num_stages = num_stages\n        self.stage_loss_weights = stage_loss_weights\n        self.proposal_feature_channel = proposal_feature_channel\n        self.merge_cls_scores = merge_cls_scores\n        self.recursive = recursive\n        
self.post_assign = post_assign\n        self.mask_out_stride = mask_out_stride\n        self.hard_target = hard_target\n        self.assign_stages = assign_stages\n        self.do_panoptic = do_panoptic\n        self.num_thing_classes = num_thing_classes\n        self.num_stuff_classes = num_stuff_classes\n        self.num_classes = num_thing_classes + num_stuff_classes\n        self.mask_assign_stride = mask_assign_stride\n        self.thing_label_in_seg = thing_label_in_seg\n        self.num_proposals = num_proposals\n        super().__init__(\n            mask_head=mask_head,\n            train_cfg=train_cfg,\n            test_cfg=test_cfg,\n            **kwargs)\n        # train_cfg is None when running test.py\n        if train_cfg is not None:\n            for stage in range(num_stages):\n                assert isinstance(self.mask_sampler[stage], MaskPseudoSampler), \\\n                    'Sparse Mask only supports `MaskPseudoSampler`'\n\n    def init_bbox_head(self, mask_roi_extractor, mask_head):\n        \"\"\"No-op: this head has no box branch, so nothing is built here.\n\n        Args:\n            mask_roi_extractor (dict): Config of box roi extractor.\n            mask_head (dict): Config of box head.\n        \"\"\"\n        pass\n\n    def init_assigner_sampler(self):\n        \"\"\"Initialize assigner and sampler for each stage.\"\"\"\n        self.mask_assigner = []\n        self.mask_sampler = []\n        if self.train_cfg is not None:\n            for idx, rcnn_train_cfg in enumerate(self.train_cfg):\n                self.mask_assigner.append(\n                    build_assigner(rcnn_train_cfg.assigner))\n                self.current_stage = idx\n                self.mask_sampler.append(\n                    build_sampler(rcnn_train_cfg.sampler, context=self))\n\n    def init_weights(self):\n        for i in range(self.num_stages):\n            self.mask_head[i].init_weights()\n\n    def init_mask_head(self, mask_roi_extractor, mask_head):\n        
\"\"\"Initialize mask head and mask roi extractor.\n\n        Args:\n            mask_roi_extractor (dict): Config of mask roi extractor.\n            mask_head (dict): Config of mask in mask head.\n        \"\"\"\n        self.mask_head = nn.ModuleList()\n        if not isinstance(mask_head, list):\n            mask_head = [mask_head for _ in range(self.num_stages)]\n        assert len(mask_head) == self.num_stages\n        for head in mask_head:\n            self.mask_head.append(build_head(head))\n        if self.recursive:\n            for i in range(self.num_stages):\n                self.mask_head[i] = self.mask_head[0]\n\n    def _mask_forward(self, stage, x, object_feats, mask_preds, img_metas=None):\n        mask_head = self.mask_head[stage]\n        cls_score, mask_preds, object_feats = mask_head(x, object_feats, mask_preds, img_metas=img_metas)\n        if mask_head.mask_upsample_stride > 1 and (stage == self.num_stages - 1 or self.training):\n            scaled_mask_preds = F.interpolate(\n                mask_preds,\n                scale_factor=mask_head.mask_upsample_stride,\n                align_corners=False,\n                mode='bilinear'\n            )\n        else:\n            scaled_mask_preds = mask_preds\n\n        mask_results = dict(\n            cls_score=cls_score,\n            mask_preds=mask_preds,\n            scaled_mask_preds=scaled_mask_preds,\n            object_feats=object_feats\n        )\n        return mask_results\n\n    def forward_train(self,\n                      x,\n                      proposal_feats,\n                      mask_preds,\n                      cls_score,\n                      ref_img_metas,\n                      gt_masks,\n                      gt_labels,\n                      gt_bboxes_ignore=None,\n                      imgs_whwh=None,\n                      gt_bboxes=None,\n                      gt_sem_seg=None,\n                      gt_sem_cls=None):\n\n        num_imgs = 
len(ref_img_metas)\n        num_frames = len(ref_img_metas[0])\n        if self.mask_head[0].mask_upsample_stride > 1:\n            prev_mask_preds = F.interpolate(\n                mask_preds.detach(),\n                scale_factor=self.mask_head[0].mask_upsample_stride,\n                mode='bilinear',\n                align_corners=False)\n        else:\n            prev_mask_preds = mask_preds.detach()\n\n        if cls_score is not None:\n            prev_cls_score = cls_score.detach()\n        else:\n            prev_cls_score = None\n\n        if self.hard_target:\n            gt_masks = [x.bool().float() for x in gt_masks]\n        else:\n            gt_masks = gt_masks\n\n        object_feats = proposal_feats\n        all_stage_loss = {}\n        all_stage_mask_results = []\n        assign_results = []\n        for stage in range(self.num_stages):\n            mask_results = self._mask_forward(stage, x, object_feats, mask_preds, img_metas=None)\n            all_stage_mask_results.append(mask_results)\n            mask_preds = mask_results['mask_preds']\n            scaled_mask_preds = mask_results['scaled_mask_preds']\n            cls_score = mask_results['cls_score']\n            object_feats = mask_results['object_feats']\n\n            if self.post_assign:\n                prev_mask_preds = scaled_mask_preds.detach()\n                prev_cls_score = cls_score.detach()\n\n            sampling_results = []\n            if stage < self.assign_stages:\n                assign_results = []\n            for i in range(num_imgs):\n                for j in range(num_frames):\n                    if stage < self.assign_stages:\n                        mask_for_assign = prev_mask_preds[i * num_frames + j][:self.num_proposals]\n                        if prev_cls_score is not None:\n                            cls_for_assign = prev_cls_score[i * num_frames + j][:self.num_proposals, :self.num_thing_classes]\n                        else:\n                          
  cls_for_assign = None\n                        assign_result = self.mask_assigner[stage].assign(\n                            mask_for_assign, cls_for_assign, gt_masks[i][j],\n                            gt_labels[i][:,1][gt_labels[i][:,0]==j], img_meta=None)\n                        assign_results.append(assign_result)\n                    sampling_result = self.mask_sampler[stage].sample(\n                        assign_results[i * num_frames + j], scaled_mask_preds[i * num_frames + j], gt_masks[i][j])\n                    sampling_results.append(sampling_result)\n            mask_targets = self.mask_head[stage].get_targets(\n                sampling_results,\n                self.train_cfg[stage],\n                True,\n                gt_sem_seg=gt_sem_seg,\n                gt_sem_cls=gt_sem_cls)\n\n            single_stage_loss = self.mask_head[stage].loss(\n                object_feats,\n                cls_score,\n                scaled_mask_preds,\n                *mask_targets,\n                imgs_whwh=imgs_whwh)\n            for key, value in single_stage_loss.items():\n                all_stage_loss[f's{stage}_{key}'] = value * \\\n                                    self.stage_loss_weights[stage]\n\n            if not self.post_assign:\n                prev_mask_preds = scaled_mask_preds.detach()\n                prev_cls_score = cls_score.detach()\n\n        bs_nf, num_query, c, ks1, ks2 = object_feats.size()\n        bs_nf2, c2, h, w = x.size()\n        assert ks1 == ks2\n        assert bs_nf == bs_nf2\n        assert bs_nf == num_frames * num_imgs\n        assert c == c2\n        features = {\n            \"obj_feats\" : object_feats.reshape((num_imgs, num_frames, num_query, c, ks1, ks2)),\n            # \"x_feats\":self.mask_head[-1].feat_transform(x).reshape((num_imgs, num_frames, c, h, w)),\n            \"x_feats\": x.reshape((num_imgs, num_frames, c, h, w)),\n            \"cls_scores\": cls_score.reshape((num_imgs, num_frames, num_query, 
self.num_classes)),\n            \"masks\": mask_preds.reshape((num_imgs, num_frames, num_query, h, w)),\n        }\n        return all_stage_loss, features\n\n    def simple_test(self,\n                    x,\n                    proposal_feats,\n                    mask_preds,\n                    cls_score,\n                    img_metas,\n                    ref_img_metas,\n                    imgs_whwh=None,\n                    rescale=False):\n\n        # Decode initial proposals\n        num_imgs = len(ref_img_metas)\n        num_frames = len(ref_img_metas[0])\n        # num_proposals = proposal_feats.size(1)\n\n        object_feats = proposal_feats\n        for stage in range(self.num_stages):\n            mask_results = self._mask_forward(stage, x, object_feats,\n                                              mask_preds)\n            object_feats = mask_results['object_feats']\n            cls_score = mask_results['cls_score']\n            mask_preds = mask_results['mask_preds']\n            scaled_mask_preds = mask_results['scaled_mask_preds']\n\n        num_classes = self.mask_head[-1].num_classes\n        results = []\n\n        if self.mask_head[-1].loss_cls.use_sigmoid:\n            cls_score = cls_score.sigmoid()\n        else:\n            cls_score = cls_score.softmax(-1)[..., :-1]\n\n        bs_nf, num_query, c, ks1, ks2 = object_feats.size()\n        bs_nf2, c2, h, w = x.size()\n        assert ks1 == ks2\n        assert bs_nf == bs_nf2\n        assert bs_nf == num_frames * num_imgs\n        assert c == c2\n        features = {\n            \"obj_feats\": object_feats.reshape((num_imgs, num_frames, num_query, c, ks1, ks2)),\n            # \"x_feats\":self.mask_head[-1].feat_transform(x).reshape((num_imgs, num_frames, c, h, w)),\n            \"x_feats\": x.reshape((num_imgs, num_frames, c, h, w)),\n            \"cls_scores\": cls_score.reshape((num_imgs, num_frames, num_query, self.num_classes)),\n            \"masks\": 
mask_preds.reshape((num_imgs, num_frames, num_query, h, w)),\n        }\n\n        if self.do_panoptic:\n            raise NotImplementedError\n            # for img_id in range(num_imgs):\n            #     single_result = self.get_panoptic(cls_score[img_id],\n            #                                       scaled_mask_preds[img_id],\n            #                                       self.test_cfg,\n            #                                       ref_img_metas[img_id])\n            #     results.append(single_result)\n        else:\n            for img_id in range(num_imgs):\n                for frame_id in range(num_frames):\n                    cls_score_per_img = cls_score[img_id * num_frames + frame_id]\n                    # h, quite tricky here, a bounding box can predict multiple results with different labels\n                    scores_per_img, topk_indices = cls_score_per_img.flatten(0, 1).topk(\n                            self.test_cfg.max_per_img, sorted=True)\n                    mask_indices = topk_indices // num_classes\n                    # Use the following when torch >= 1.9.0\n                    # mask_indices = torch.div(topk_indices, num_classes, rounding_mode='floor')\n                    labels_per_img = topk_indices % num_classes\n                    masks_per_img = scaled_mask_preds[img_id * num_frames + frame_id][mask_indices]\n                    single_result = self.mask_head[-1].get_seg_masks(\n                        masks_per_img, labels_per_img, scores_per_img,\n                        self.test_cfg, img_metas[img_id])\n                    results.append(single_result)\n        return results, features\n\n    def aug_test(self, features, proposal_list, img_metas, rescale=False):\n        raise NotImplementedError('SparseMask does not support `aug_test`')\n\n    def forward_dummy(self, x, proposal_boxes, proposal_feats, img_metas):\n        \"\"\"Dummy forward function when do the flops computing.\"\"\"\n        
all_stage_mask_results = []\n        num_imgs = len(img_metas)\n        num_proposals = proposal_feats.size(1)\n        C, H, W = x.shape[-3:]\n        mask_preds = proposal_feats.bmm(x.view(num_imgs, C, -1)).view(\n            num_imgs, num_proposals, H, W)\n        object_feats = proposal_feats\n        for stage in range(self.num_stages):\n            mask_results = self._mask_forward(stage, x, object_feats,\n                                              mask_preds, img_metas)\n            all_stage_mask_results.append(mask_results)\n        return all_stage_mask_results\n\n    def get_panoptic(self, cls_scores, mask_preds, test_cfg, img_meta):\n        # resize mask predictions back\n        scores = cls_scores[:self.num_proposals][:, :self.num_thing_classes]\n        thing_scores, thing_labels = scores.max(dim=1)\n        stuff_scores = cls_scores[\n            self.num_proposals:][:, self.num_thing_classes:].diag()\n        stuff_labels = torch.arange(\n            0, self.num_stuff_classes) + self.num_thing_classes\n        stuff_labels = stuff_labels.to(thing_labels.device)\n\n        total_masks = self.mask_head[-1].rescale_masks(mask_preds, img_meta)\n        total_scores = torch.cat([thing_scores, stuff_scores], dim=0)\n        total_labels = torch.cat([thing_labels, stuff_labels], dim=0)\n\n        panoptic_result = self.merge_stuff_thing(total_masks, total_labels,\n                                                 total_scores,\n                                                 test_cfg.merge_stuff_thing)\n        return dict(pan_results=panoptic_result)\n\n    def merge_stuff_thing(self,\n                          total_masks,\n                          total_labels,\n                          total_scores,\n                          merge_cfg=None):\n\n        H, W = total_masks.shape[-2:]\n        panoptic_seg = total_masks.new_full((H, W),\n                                            self.num_classes,\n                                            
dtype=torch.long)\n\n        cur_prob_masks = total_scores.view(-1, 1, 1) * total_masks\n        cur_mask_ids = cur_prob_masks.argmax(0)\n\n        # sort instance outputs by scores\n        sorted_inds = torch.argsort(-total_scores)\n        current_segment_id = 0\n\n        for k in sorted_inds:\n            pred_class = total_labels[k].item()\n            isthing = pred_class < self.num_thing_classes\n            if isthing and total_scores[k] < merge_cfg.instance_score_thr:\n                continue\n\n            mask = cur_mask_ids == k\n            mask_area = mask.sum().item()\n            original_area = (total_masks[k] >= 0.5).sum().item()\n\n            if mask_area > 0 and original_area > 0:\n                if mask_area / original_area < merge_cfg.overlap_thr:\n                    continue\n\n                panoptic_seg[mask] = total_labels[k] \\\n                    + current_segment_id * INSTANCE_OFFSET\n                current_segment_id += 1\n\n        return panoptic_seg.cpu().numpy()\n"
  },
  {
    "path": "knet_vis/tracker/kernel_update_head.py",
    "content": "import numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom mmcv.cnn import (ConvModule, bias_init_with_prob, build_activation_layer,\n                      build_norm_layer)\nfrom mmcv.cnn.bricks.transformer import (FFN, MultiheadAttention,\n                                         build_transformer_layer)\nfrom mmcv.runner import force_fp32\n\nfrom mmdet.core import multi_apply\nfrom mmdet.models.builder import HEADS, build_loss\nfrom mmdet.models.dense_heads.atss_head import reduce_mean\nfrom mmdet.models.losses import accuracy\nfrom mmdet.utils import get_root_logger\n\nfrom mmtrack.transform import outs2results\n\n@HEADS.register_module()\nclass KernelUpdateHeadVideo(nn.Module):\n\n    def __init__(self,\n                 with_cls=True,\n                 num_proposals=100,\n                 num_classes=80,\n                 num_ffn_fcs=2,\n                 num_heads=8,\n                 num_cls_fcs=1,\n                 num_mask_fcs=3,\n                 feedforward_channels=2048,\n                 in_channels=256,\n                 out_channels=256,\n                 dropout=0.0,\n                 mask_thr=0.5,\n                 act_cfg=dict(type='ReLU', inplace=True),\n                 ffn_act_cfg=dict(type='ReLU', inplace=True),\n                 conv_kernel_size=3,\n                 feat_transform_cfg=None,\n                 hard_mask_thr=0.5,\n                 kernel_init=False,\n                 with_ffn=True,\n                 mask_out_stride=4,\n                 relative_coors=False,\n                 relative_coors_off=False,\n                 feat_gather_stride=1,\n                 mask_transform_stride=1,\n                 mask_upsample_stride=1,\n                 num_thing_classes=80,\n                 num_stuff_classes=53,\n                 mask_assign_stride=4,\n                 ignore_label=255,\n                 thing_label_in_seg=0,\n                 # query fusion\n                 
query_merge_method='mean',\n\n                 kernel_updator_cfg=dict(\n                     type='DynamicConv',\n                     in_channels=256,\n                     feat_channels=64,\n                     out_channels=256,\n                     input_feat_shape=1,\n                     act_cfg=dict(type='ReLU', inplace=True),\n                     norm_cfg=dict(type='LN')),\n                 loss_rank=None,\n                 loss_mask=dict(\n                     type='CrossEntropyLoss', use_mask=True, loss_weight=1.0),\n                 loss_dice=dict(type='DiceLoss', loss_weight=3.0),\n                 loss_cls=dict(\n                     type='FocalLoss',\n                     use_sigmoid=True,\n                     gamma=2.0,\n                     alpha=0.25,\n                     loss_weight=2.0)):\n        super().__init__()\n        self.num_proposals = num_proposals\n        self.num_classes = num_classes\n        self.loss_cls = build_loss(loss_cls)\n        self.loss_mask = build_loss(loss_mask)\n        self.loss_dice = build_loss(loss_dice)\n        if loss_rank is not None:\n            self.loss_rank = build_loss(loss_rank)\n        else:\n            self.loss_rank = loss_rank\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.mask_thr = mask_thr\n        self.fp16_enabled = False\n        self.dropout = dropout\n\n        self.num_heads = num_heads\n        self.hard_mask_thr = hard_mask_thr\n        self.kernel_init = kernel_init\n        self.with_ffn = with_ffn\n        self.mask_out_stride = mask_out_stride\n        self.relative_coors = relative_coors\n        self.relative_coors_off = relative_coors_off\n        self.conv_kernel_size = conv_kernel_size\n        self.feat_gather_stride = feat_gather_stride\n        self.mask_transform_stride = mask_transform_stride\n        self.mask_upsample_stride = mask_upsample_stride\n\n        self.num_thing_classes = num_thing_classes\n        
self.num_stuff_classes = num_stuff_classes\n        self.mask_assign_stride = mask_assign_stride\n        self.ignore_label = ignore_label\n        self.thing_label_in_seg = thing_label_in_seg\n\n        self.attention = MultiheadAttention(in_channels * conv_kernel_size**2,\n                                            num_heads, dropout)\n        self.attention_norm = build_norm_layer(\n            dict(type='LN'), in_channels * conv_kernel_size**2)[1]\n\n        self.kernel_update_conv = build_transformer_layer(kernel_updator_cfg)\n\n        if feat_transform_cfg is not None:\n            kernel_size = feat_transform_cfg.pop('kernel_size', 1)\n            self.feat_transform = ConvModule(\n                in_channels,\n                in_channels,\n                kernel_size,\n                stride=feat_gather_stride,\n                padding=int(feat_gather_stride // 2),\n                **feat_transform_cfg)\n        else:\n            self.feat_transform = None\n\n        if self.with_ffn:\n            self.ffn = FFN(\n                in_channels,\n                feedforward_channels,\n                num_ffn_fcs,\n                act_cfg=ffn_act_cfg,\n                ffn_drop=dropout)\n            self.ffn_norm = build_norm_layer(dict(type='LN'), in_channels)[1]\n\n        self.with_cls = with_cls\n        if self.with_cls:\n            self.cls_fcs = nn.ModuleList()\n            for _ in range(num_cls_fcs):\n                self.cls_fcs.append(\n                    nn.Linear(in_channels, in_channels, bias=False))\n                self.cls_fcs.append(\n                    build_norm_layer(dict(type='LN'), in_channels)[1])\n                self.cls_fcs.append(build_activation_layer(act_cfg))\n\n            if self.loss_cls.use_sigmoid:\n                self.fc_cls = nn.Linear(in_channels, self.num_classes)\n            else:\n                self.fc_cls = nn.Linear(in_channels, self.num_classes + 1)\n\n\n        # query fusion\n        
self.query_merge_method = query_merge_method\n        if self.query_merge_method == 'attention' and self.with_cls:\n            _num_head = 8\n            _drop_out = 0.\n            self.query_merge_attn = MultiheadAttention(self.in_channels, _num_head, _drop_out, batch_first=True)\n            self.query_merge_norm = build_norm_layer(dict(type='LN'),  self.in_channels)[1]\n            self.query_merge_ffn = FFN(\n                self.in_channels,\n                self.in_channels * 8,\n                num_ffn_fcs=2,\n                act_cfg=dict(type='ReLU', inplace=True),\n                ffn_drop=0.)\n            self.query_merge_ffn_norm = build_norm_layer(dict(type='LN'), self.in_channels)[1]\n        elif self.query_merge_method == 'attention_pos' and self.with_cls:\n            _num_head = 8\n            _drop_out = 0.\n            self.query_merge_attn = MultiheadAttention(self.in_channels, _num_head, _drop_out, batch_first=True)\n            self.query_merge_norm = build_norm_layer(dict(type='LN'), self.in_channels)[1]\n            self.query_merge_ffn = FFN(\n                self.in_channels,\n                self.in_channels * 8,\n                num_ffn_fcs=2,\n                act_cfg=dict(type='ReLU', inplace=True),\n                ffn_drop=0.)\n            self.query_merge_ffn_norm = build_norm_layer(dict(type='LN'), self.in_channels)[1]\n\n        self.mask_fcs = nn.ModuleList()\n        for _ in range(num_mask_fcs):\n            self.mask_fcs.append(\n                nn.Linear(in_channels, in_channels, bias=False))\n            self.mask_fcs.append(\n                build_norm_layer(dict(type='LN'), in_channels)[1])\n            self.mask_fcs.append(build_activation_layer(act_cfg))\n\n        self.fc_mask = nn.Linear(in_channels, out_channels)\n\n    def init_weights(self):\n        \"\"\"Use xavier initialization for all weight parameter and set\n        classification head bias as a specific value when use focal loss.\"\"\"\n        for p in 
self.parameters():\n            if p.dim() > 1:\n                nn.init.xavier_uniform_(p)\n            else:\n                # adopt the default initialization for\n                # the weight and bias of the layer norm\n                pass\n        if self.loss_cls.use_sigmoid:\n            bias_init = bias_init_with_prob(0.01)\n            nn.init.constant_(self.fc_cls.bias, bias_init)\n        if self.kernel_init:\n            logger = get_root_logger()\n            logger.info(\n                'mask kernel in mask head is normal initialized by std 0.01')\n            nn.init.normal_(self.fc_mask.weight, mean=0, std=0.01)\n\n    def forward(self,\n                x,\n                proposal_feat,\n                mask_preds,\n                prev_cls_score=None,\n                mask_shape=None,\n                img_metas=None,\n                pos=None):\n        if len(proposal_feat.size()) == 6:\n            assert not self.with_cls\n            is_gather_query = False\n            N, _, num_proposals = proposal_feat.shape[:3]\n        else:\n            assert self.with_cls\n            is_gather_query = True\n            N, num_proposals = proposal_feat.shape[:2]\n        assert self.num_proposals == num_proposals\n        _, num_frames, C, H, W = x.size()\n        if self.feat_transform is not None:\n            x = self.feat_transform(x.reshape((N * num_frames, C, H, W))).reshape((N, num_frames, C, H, W))\n\n        mask_h, mask_w = mask_preds.shape[-2:]\n        if mask_h != H or mask_w != W:\n            # mask_preds: [B, F, N, h, w] -> [B*F, N, h, w] so that the 2D resize receives a 4D tensor\n            gather_mask = F.interpolate(\n                mask_preds.reshape((N * num_frames, num_proposals, mask_h, mask_w)),\n                (H, W), align_corners=False, mode='bilinear').reshape((N, num_frames, num_proposals, H, W))\n        else:\n            gather_mask = mask_preds\n\n        sigmoid_masks = gather_mask.sigmoid()\n        nonzero_inds = sigmoid_masks > self.hard_mask_thr\n        sigmoid_masks = nonzero_inds.float()\n\n        # einsum is faster than bmm by 30%\n    
    if is_gather_query:\n            # x_feat = torch.einsum('bfnhw,bfchw->bnc', sigmoid_masks, x)\n            if self.query_merge_method == 'mean':\n                x_feat = torch.einsum('bfnhw,bfchw->bfnc', sigmoid_masks, x).mean(1)\n            elif self.query_merge_method == 'attention':\n                x_feat = torch.einsum('bfnhw,bfchw->bfnc', sigmoid_masks, x)\n                x_feat = x_feat.reshape((N, num_frames * num_proposals, self.in_channels))\n                assert proposal_feat.size()[-2:] == (1,1), \"Only supporting kernel size = 1\"\n                init_query = proposal_feat.reshape(N, num_proposals, self.in_channels).detach()\n                x_feat = self.query_merge_attn(query=init_query, key=x_feat, value=x_feat)\n                x_feat = self.query_merge_norm(x_feat)\n                x_feat = self.query_merge_ffn_norm(self.query_merge_ffn(x_feat))\n            elif self.query_merge_method == 'attention_pos':\n                x_feat = torch.einsum('bfnhw,bfchw->bfnc', sigmoid_masks, x)\n                x_feat = x_feat.reshape((N, num_frames * num_proposals, self.in_channels))\n                assert proposal_feat.size()[-2:] == (1, 1), \"Only supporting kernel size = 1\"\n                init_query = proposal_feat.reshape(N, num_proposals, self.in_channels).detach()\n                query_pos = pos.repeat(N, 1, 1)\n                key_pos = query_pos.repeat(1, num_frames, 1)\n                x_feat = self.query_merge_attn(query=init_query, key=x_feat, value=x_feat,\n                                               query_pos=query_pos, key_pos=key_pos)\n                x_feat = self.query_merge_norm(x_feat)\n                x_feat = self.query_merge_ffn_norm(self.query_merge_ffn(x_feat))\n            else:\n                raise NotImplementedError\n        else:\n            x_feat = torch.einsum('bfnhw,bfchw->bfnc', sigmoid_masks, x)\n\n        # obj_feat in shape [B, N, C, K, K] -> [B, N, C, K*K] -> [B, N, K*K, C]\n        if 
is_gather_query:\n            proposal_feat = proposal_feat.reshape(N, num_proposals, self.in_channels, -1).permute(0, 1, 3, 2)\n            obj_feat = self.kernel_update_conv(x_feat, proposal_feat)\n        else:\n            proposal_feat = proposal_feat.reshape(N * num_frames, num_proposals, self.in_channels, -1).permute(0, 1, 3, 2)\n            obj_feat = self.kernel_update_conv(x_feat.reshape(N * num_frames, num_proposals, C), proposal_feat)\n            N *= num_frames\n\n        # [B, N, K*K, C] -> [B, N, K*K*C] -> [N, B, K*K*C]\n        obj_feat = obj_feat.reshape(N, num_proposals, -1).permute(1, 0, 2)\n        obj_feat = self.attention_norm(self.attention(obj_feat))\n        # [N, B, K*K*C] -> [B, N, K*K*C]\n        obj_feat = obj_feat.permute(1, 0, 2)\n\n        # obj_feat in shape [B, N, K*K*C] -> [B, N, K*K, C]\n        obj_feat = obj_feat.reshape(N, num_proposals, -1, self.in_channels)\n\n        # FFN\n        if self.with_ffn:\n            obj_feat = self.ffn_norm(self.ffn(obj_feat))\n\n        mask_feat = obj_feat\n\n        if is_gather_query:\n            cls_feat = obj_feat.sum(-2)\n            for cls_layer in self.cls_fcs:\n                cls_feat = cls_layer(cls_feat)\n            cls_score = self.fc_cls(cls_feat).view(N, num_proposals, -1)\n        else:\n            cls_score = None\n\n        for reg_layer in self.mask_fcs:\n            mask_feat = reg_layer(mask_feat)\n        # [B, N, K*K, C] -> [B, N, C, K*K]\n        mask_feat = self.fc_mask(mask_feat).permute(0, 1, 3, 2)\n\n        if (self.mask_transform_stride == 2 and self.feat_gather_stride == 1):\n            mask_x = F.interpolate(\n                x, scale_factor=0.5, mode='bilinear', align_corners=False)\n            H, W = mask_x.shape[-2:]\n            raise NotImplementedError\n        else:\n            mask_x = x\n        # group conv is 5x faster than unfold and uses about 1/5 memory\n        # Group conv vs. unfold vs. 
concat batch, 2.9ms :13.5ms :3.8ms\n        # Group conv vs. unfold vs. concat batch, 278 : 1420 : 369\n        # fold_x = F.unfold(\n        #     mask_x,\n        #     self.conv_kernel_size,\n        #     padding=int(self.conv_kernel_size // 2))\n        # mask_feat = mask_feat.reshape(N, num_proposals, -1)\n        # new_mask_preds = torch.einsum('bnc,bcl->bnl', mask_feat, fold_x)\n        # [B, N, C, K*K] -> [B*N, C, K, K]\n        mask_feat = mask_feat.reshape(N, num_proposals, C,\n                                      self.conv_kernel_size,\n                                      self.conv_kernel_size)\n        # [B, C, H, W] -> [1, B*C, H, W]\n        if is_gather_query:\n            new_mask_preds = []\n            for i in range(N):\n                new_mask_preds.append(\n                    F.conv2d(\n                        mask_x[i],\n                        mask_feat[i],\n                        padding=int(self.conv_kernel_size // 2)))\n\n            new_mask_preds = torch.stack(new_mask_preds, dim=0)\n            assert new_mask_preds.size() == (N, num_frames, num_proposals, H, W)\n        else:\n            N = N // num_frames\n            new_mask_preds = []\n            for i in range(N):\n                for j in range(num_frames):\n                    new_mask_preds.append(\n                        F.conv2d(\n                            mask_x[i][j][None],\n                            mask_feat[i * num_frames + j],\n                            padding=int(self.conv_kernel_size // 2)))\n            new_mask_preds = torch.cat(new_mask_preds, dim=0)\n            new_mask_preds = new_mask_preds.reshape(N, num_frames, num_proposals, H, W)\n            assert new_mask_preds.size() == (N, num_frames, num_proposals, H, W)\n        if self.mask_transform_stride == 2:\n            new_mask_preds = F.interpolate(\n                new_mask_preds,\n                scale_factor=2,\n                mode='bilinear',\n                align_corners=False)\n     
       raise NotImplementedError\n\n        if mask_shape is not None and mask_shape[0] != H:\n            new_mask_preds = F.interpolate(\n                new_mask_preds,\n                mask_shape,\n                align_corners=False,\n                mode='bilinear')\n            raise NotImplementedError\n        if is_gather_query:\n            return cls_score, new_mask_preds, obj_feat.permute(0, 1, 3, 2).reshape(\n                N, num_proposals, self.in_channels, self.conv_kernel_size,\n                self.conv_kernel_size)\n        else:\n            return None, new_mask_preds, obj_feat.permute(0, 1, 3, 2).reshape(\n                N, num_frames, num_proposals, self.in_channels, self.conv_kernel_size, self.conv_kernel_size)\n\n    @force_fp32(apply_to=('cls_score', 'mask_pred'))\n    def loss(self,\n             object_feats,\n             cls_score,\n             mask_pred,\n             labels,\n             label_weights,\n             mask_targets,\n             mask_weights,\n             imgs_whwh=None,\n             reduction_override=None,\n             **kwargs):\n\n        losses = dict()\n        bg_class_ind = self.num_classes\n        # note: in Sparse R-CNN, num_gt == num_pos\n        pos_inds = (labels >= 0) & (labels < bg_class_ind)\n        num_pos = pos_inds.sum().float()\n        avg_factor = reduce_mean(num_pos).clamp_(min=1.0)\n\n        num_preds = mask_pred.shape[0] * mask_pred.shape[1]\n        if cls_score is not None:\n            assert mask_pred.shape[0] == cls_score.shape[0]\n            assert mask_pred.shape[1] == cls_score.shape[1]\n\n        if cls_score is not None:\n            if cls_score.numel() > 0:\n                losses['loss_cls'] = self.loss_cls(\n                    cls_score.view(num_preds, -1),\n                    labels,\n                    label_weights,\n                    avg_factor=avg_factor,\n                    reduction_override=reduction_override)\n                losses['pos_acc'] = accuracy(\n 
                   cls_score.view(num_preds, -1)[pos_inds], labels[pos_inds])\n        if mask_pred is not None:\n            bool_pos_inds = pos_inds.type(torch.bool)\n            # 0~self.num_classes-1 are FG, self.num_classes is BG\n            # do not perform bounding box regression for BG anymore.\n            H, W = mask_pred.shape[-2:]\n            if pos_inds.any():\n                pos_mask_pred = mask_pred.reshape(num_preds, H,\n                                                  W)[bool_pos_inds]\n                pos_mask_targets = mask_targets[bool_pos_inds]\n                losses['loss_mask'] = self.loss_mask(pos_mask_pred,\n                                                     pos_mask_targets)\n                losses['loss_dice'] = self.loss_dice(pos_mask_pred,\n                                                     pos_mask_targets)\n\n                if self.loss_rank is not None:\n                    batch_size = mask_pred.size(0)\n                    rank_target = mask_targets.new_full((batch_size, H, W),\n                                                        self.ignore_label,\n                                                        dtype=torch.long)\n                    rank_inds = pos_inds.view(batch_size,\n                                              -1).nonzero(as_tuple=False)\n                    batch_mask_targets = mask_targets.view(\n                        batch_size, -1, H, W).bool()\n                    for i in range(batch_size):\n                        curr_inds = (rank_inds[:, 0] == i)\n                        curr_rank = rank_inds[:, 1][curr_inds]\n                        for j in curr_rank:\n                            rank_target[i][batch_mask_targets[i][j]] = j\n                    losses['loss_rank'] = self.loss_rank(\n                        mask_pred, rank_target, ignore_index=self.ignore_label)\n            else:\n                losses['loss_mask'] = mask_pred.sum() * 0\n                losses['loss_dice'] = 
mask_pred.sum() * 0\n                if self.loss_rank is not None:\n                    losses['loss_rank'] = mask_pred.sum() * 0\n\n        return losses\n\n    def _get_target_single(self, pos_inds, neg_inds, pos_mask, neg_mask,\n                           pos_gt_mask, pos_gt_labels, gt_sem_seg, gt_sem_cls,\n                           cfg):\n\n        num_pos = pos_mask.size(0)\n        num_neg = neg_mask.size(0)\n        num_samples = num_pos + num_neg\n        H, W = pos_mask.shape[-2:]\n        # original implementation uses new_zeros since BG are set to be 0\n        # now use empty & fill because BG cat_id = num_classes,\n        # FG cat_id = [0, num_classes-1]\n        labels = pos_mask.new_full((num_samples, ),\n                                   self.num_classes,\n                                   dtype=torch.long)\n        label_weights = pos_mask.new_zeros((num_samples, self.num_classes))\n        mask_targets = pos_mask.new_zeros(num_samples, H, W)\n        mask_weights = pos_mask.new_zeros(num_samples, H, W)\n        if num_pos > 0:\n            labels[pos_inds] = pos_gt_labels\n            pos_weight = 1.0 if cfg.pos_weight <= 0 else cfg.pos_weight\n            label_weights[pos_inds] = pos_weight\n            pos_mask_targets = pos_gt_mask\n            mask_targets[pos_inds, ...] = pos_mask_targets\n            mask_weights[pos_inds, ...] 
= 1\n\n        if num_neg > 0:\n            label_weights[neg_inds] = 1.0\n\n        if gt_sem_cls is not None and gt_sem_seg is not None:\n            sem_labels = pos_mask.new_full((self.num_stuff_classes, ),\n                                           self.num_classes,\n                                           dtype=torch.long)\n            sem_targets = pos_mask.new_zeros(self.num_stuff_classes, H, W)\n            sem_weights = pos_mask.new_zeros(self.num_stuff_classes, H, W)\n            sem_stuff_weights = torch.eye(\n                self.num_stuff_classes, device=pos_mask.device)\n            sem_thing_weights = pos_mask.new_zeros(\n                (self.num_stuff_classes, self.num_thing_classes))\n            sem_label_weights = torch.cat(\n                [sem_thing_weights, sem_stuff_weights], dim=-1)\n            if len(gt_sem_cls) > 0:\n                sem_inds = gt_sem_cls - self.num_thing_classes\n                sem_inds = sem_inds.long()\n                sem_labels[sem_inds] = gt_sem_cls.long()\n                sem_targets[sem_inds] = gt_sem_seg\n                sem_weights[sem_inds] = 1\n\n            label_weights[:, self.num_thing_classes:] = 0\n            labels = torch.cat([labels, sem_labels])\n            label_weights = torch.cat([label_weights, sem_label_weights])\n            mask_targets = torch.cat([mask_targets, sem_targets])\n            mask_weights = torch.cat([mask_weights, sem_weights])\n\n        return labels, label_weights, mask_targets, mask_weights\n\n    def get_targets(self,\n                    sampling_results,\n                    rcnn_train_cfg,\n                    concat=True,\n                    gt_sem_seg=None,\n                    gt_sem_cls=None):\n        num_imgs = len(sampling_results)\n        pos_inds_list = [res.pos_inds for res in sampling_results]\n        neg_inds_list = [res.neg_inds for res in sampling_results]\n        pos_mask_list = [res.pos_masks for res in sampling_results]\n        
neg_mask_list = [res.neg_masks for res in sampling_results]\n        pos_gt_mask_list = [res.pos_gt_masks for res in sampling_results]\n        pos_gt_labels_list = [res.pos_gt_labels for res in sampling_results]\n        if gt_sem_seg is None:\n            gt_sem_seg = [None] * num_imgs\n            gt_sem_cls = [None] * num_imgs\n\n        labels, label_weights, mask_targets, mask_weights = multi_apply(\n            self._get_target_single,\n            pos_inds_list,\n            neg_inds_list,\n            pos_mask_list,\n            neg_mask_list,\n            pos_gt_mask_list,\n            pos_gt_labels_list,\n            gt_sem_seg,\n            gt_sem_cls,\n            cfg=rcnn_train_cfg)\n        if concat:\n            labels = torch.cat(labels, 0)\n            label_weights = torch.cat(label_weights, 0)\n            mask_targets = torch.cat(mask_targets, 0)\n            mask_weights = torch.cat(mask_weights, 0)\n        return labels, label_weights, mask_targets, mask_weights\n\n    def rescale_masks(self, masks_per_img, img_meta):\n        h, w, _ = img_meta['img_shape']\n        masks_per_img = F.interpolate(\n            masks_per_img.unsqueeze(0).sigmoid(),\n            size=img_meta['batch_input_shape'],\n            mode='bilinear',\n            align_corners=False)\n\n        masks_per_img = masks_per_img[:, :, :h, :w]\n        ori_shape = img_meta['ori_shape']\n        seg_masks = F.interpolate(\n            masks_per_img,\n            size=ori_shape[:2],\n            mode='bilinear',\n            align_corners=False).squeeze(0)\n        return seg_masks\n\n    def get_seg_masks(self, masks_per_img, labels_per_img, scores_per_img,\n                      test_cfg, img_meta):\n        # resize mask predictions back\n        seg_masks = self.rescale_masks(masks_per_img, img_meta)\n        seg_masks = seg_masks > test_cfg.mask_thr\n        bbox_result, segm_result = self.segm2result(seg_masks, labels_per_img,\n                                         
           scores_per_img)\n        return bbox_result, segm_result\n\n    def segm2result(self, mask_preds, det_labels, cls_scores):\n        num_classes = self.num_classes\n        bbox_result = None\n        segm_result = [[] for _ in range(num_classes)]\n        mask_preds = mask_preds.cpu().numpy()\n        det_labels = det_labels.cpu().numpy()\n        cls_scores = cls_scores.cpu().numpy()\n        num_ins = mask_preds.shape[0]\n        # fake bboxes\n        bboxes = np.zeros((num_ins, 5), dtype=np.float32)\n        bboxes[:, -1] = cls_scores\n        bbox_result = [bboxes[det_labels == i, :] for i in range(num_classes)]\n        for idx in range(num_ins):\n            segm_result[det_labels[idx]].append(mask_preds[idx])\n        return bbox_result, segm_result\n\n    def get_seg_masks_tracking(self, masks_per_img, labels_per_img, scores_per_img, ids_per_img,\n                      test_cfg, img_meta):\n        num_ins = masks_per_img.shape[0]\n        # resize mask predictions back\n        seg_masks = self.rescale_masks(masks_per_img, img_meta)\n        seg_masks = seg_masks > test_cfg.mask_thr\n        # fake bboxes\n        bboxes = torch.zeros((num_ins, 5), dtype=torch.float32)\n        bboxes[:, -1] = scores_per_img\n        tracks = outs2results(\n            bboxes=bboxes,\n            labels=labels_per_img,\n            masks=seg_masks,\n            ids=ids_per_img,\n            num_classes=self.num_classes,\n        )\n        return tracks['bbox_results'], tracks['mask_results']\n"
  },
  {
    "path": "knet_vis/tracker/mask_hungarian_assigner.py",
    "content": "import numpy as np\nimport torch\n\nfrom mmdet.core import AssignResult, BaseAssigner\nfrom mmdet.core.bbox.builder import BBOX_ASSIGNERS\nfrom mmdet.core.bbox.match_costs.builder import build_match_cost\n\ntry:\n    from scipy.optimize import linear_sum_assignment\nexcept ImportError:\n    linear_sum_assignment = None\n\n\n@BBOX_ASSIGNERS.register_module()\nclass MaskHungarianAssignerVideo(BaseAssigner):\n    \"\"\"Computes one-to-one matching between video mask predictions and ground truth.\n\n    This class computes an assignment between the ground-truth instances of a\n    clip and the predicted queries based on the costs. The costs are a weighted\n    sum of up to four components: classification cost, mask cost, dice cost\n    and (optionally) boundary cost, where the mask-based costs are computed\n    over the masks of all frames of the clip. The targets don't include the\n    no_object, so generally there are more predictions than targets. After the\n    one-to-one matching, the un-matched predictions are treated as background.\n    Thus each query prediction will be assigned with `0` or a positive integer\n    indicating the ground truth index:\n\n    - 0: negative sample, no assigned gt\n    - positive integer: positive sample, index (1-based) of assigned gt\n\n    Args:\n        cls_cost (dict, optional): Config of the classification match cost.\n            Default `dict(type='ClassificationCost', weight=1.)`.\n        mask_cost (dict, optional): Config of the mask match cost.\n            Default `dict(type='SigmoidCost', weight=1.0)`.\n        dice_cost (dict, optional): Config of the dice match cost.\n            Default `dict()`.\n        boundary_cost (dict, optional): Config of the boundary match cost.\n            Default None, i.e. no boundary cost.\n        topk (int, optional): Number of Hungarian matching rounds. Each round\n            excludes the predictions matched in earlier rounds, so every gt\n            can be assigned up to `topk` predictions. Default 1.\n    \"\"\"\n\n    def __init__(self,\n                 cls_cost=dict(type='ClassificationCost', weight=1.),\n                 mask_cost=dict(type='SigmoidCost', weight=1.0),\n                 dice_cost=dict(),\n                 boundary_cost=None,\n                 topk=1):\n        self.cls_cost = build_match_cost(cls_cost)\n        self.mask_cost = build_match_cost(mask_cost)\n        self.dice_cost = build_match_cost(dice_cost)\n        if boundary_cost is not None:\n            self.boundary_cost = build_match_cost(boundary_cost)\n        else:\n            self.boundary_cost = None\n        self.topk = topk\n\n    def assign(self,\n               bbox_pred,\n               cls_pred,\n               gt_bboxes,\n               gt_labels,\n               gt_instance_ids,\n               img_meta=None,\n               gt_bboxes_ignore=None,\n               eps=1e-7):\n        \"\"\"Computes one-to-one matching based on the weighted costs.\n\n        This method assigns each query prediction to a ground truth or\n        background. The `assigned_gt_inds` with -1 means don't care,\n        0 means negative sample, and positive number is the index (1-based)\n        of assigned gt.\n        The assignment is done in the following steps, the order matters.\n\n        1. assign every prediction to -1\n        2. compute the weighted costs\n        3. do Hungarian matching on CPU based on the costs\n        4. assign all to 0 (background) first, then for each matched pair\n           between predictions and gts, treat this prediction as foreground\n           and assign the corresponding gt index (plus 1) to it.\n\n        Args:\n            bbox_pred (Tensor): Predicted masks of all frames (despite the\n                name), shape [num_frames, num_query, h, w].\n            cls_pred (Tensor | None): Predicted classification logits, shape\n                [num_query, num_class].\n            gt_bboxes (list[Tensor]): Ground truth masks (despite the name),\n                one tensor of shape [num_gt_in_frame, h, w] per frame.\n            gt_labels (Tensor): (frame index, class label) of every gt in\n                every frame, shape (num_gt_total, 2).\n            gt_instance_ids (Tensor): (frame index, instance id) of every gt\n                in every frame, shape (num_gt_total, 2).\n            img_meta (dict): Meta information for current image.\n            gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are\n                labelled as `ignored`. Default None.\n            eps (int | float, optional): A value added to the denominator for\n                numerical stability. Default 1e-7.\n\n        Returns:\n            tuple[:obj:`AssignResult`, Tensor]: The assigned result, and the\n                gt masks of all frames reshaped for matching, with shape\n                [num_gt, num_frames * h, w].\n        \"\"\"\n        assert gt_bboxes_ignore is None, \\\n            'Only case when gt_bboxes_ignore is None is supported.'\n        instances = torch.unique(gt_instance_ids[:,1])\n        num_frames = bbox_pred.size(0)\n        h, w = bbox_pred.shape[-2:]\n        gt_masks = []\n        gt_labels_tensor = []\n        for instance_id in instances:\n            gt_instance_frame_ids = gt_instance_ids[gt_instance_ids[:, 1] == instance_id, 0]\n            instance_masks = []\n            gt_label_id = None\n            for frame_id in range(num_frames):\n                gt_frame_instance_ids = gt_instance_ids[gt_instance_ids[:,0] == frame_id, 1]\n                gt_frame_label_ids = gt_labels[gt_labels[:,0] == frame_id, 1]\n                assert len(gt_frame_instance_ids) == len(gt_frame_label_ids)\n                if not (frame_id in gt_instance_frame_ids):\n                    gt_mask_frame = torch.zeros((h, w), device=gt_instance_frame_ids.device, dtype=torch.float)\n                else:\n                    gt_index = torch.nonzero((gt_frame_instance_ids == instance_id), as_tuple=True)[0].item()\n                    gt_mask_frame = gt_bboxes[frame_id][gt_index]\n                    gt_label_id = 
gt_frame_label_ids[gt_index].item() if gt_label_id is None else gt_label_id\n                    assert gt_label_id == gt_frame_label_ids[gt_index].item()\n                instance_masks.append(gt_mask_frame)\n            gt_masks.append(torch.stack(instance_masks))\n            gt_labels_tensor.append(gt_label_id)\n        gt_masks = torch.stack(gt_masks)\n        gt_labels_tensor = torch.tensor(gt_labels_tensor, device=gt_masks.device, dtype=torch.long)\n\n\n        num_gts, num_bboxes = len(instances), bbox_pred.size(1)\n\n        # 1. assign -1 by default\n        assigned_gt_inds = bbox_pred.new_full((num_bboxes, ), -1, dtype=torch.long)\n        assigned_labels = bbox_pred.new_full((num_bboxes, ), -1, dtype=torch.long)\n        if num_gts == 0 or num_bboxes == 0:\n            # No ground truth or boxes, return empty assignment\n            if num_gts == 0:\n                # No ground truth, assign all to background\n                assigned_gt_inds[:] = 0\n            return AssignResult(\n                num_gts, assigned_gt_inds, None, labels=assigned_labels)\n\n        # 2. 
compute the weighted costs\n        # classification, mask, dice and boundary costs.\n        pred_masks_match = torch.einsum('fqhw->qfhw', bbox_pred).reshape((num_bboxes, -1, w))\n        gt_masks_match = gt_masks.reshape((num_gts, -1, w))\n        if self.cls_cost.weight != 0 and cls_pred is not None:\n            cls_cost = self.cls_cost(cls_pred, gt_labels_tensor)\n        else:\n            cls_cost = 0\n        if self.mask_cost.weight != 0:\n            reg_cost = self.mask_cost(pred_masks_match, gt_masks_match)\n        else:\n            reg_cost = 0\n        if self.dice_cost.weight != 0:\n            dice_cost = self.dice_cost(pred_masks_match, gt_masks_match)\n        else:\n            dice_cost = 0\n        if self.boundary_cost is not None and self.boundary_cost.weight != 0:\n            b_cost = self.boundary_cost(pred_masks_match, gt_masks_match)\n        else:\n            b_cost = 0\n        cost = cls_cost + reg_cost + dice_cost + b_cost\n\n        # 3. do Hungarian matching on CPU using linear_sum_assignment\n        cost = cost.detach().cpu()\n        if linear_sum_assignment is None:\n            raise ImportError('Please run \"pip install scipy\" '\n                              'to install scipy first.')\n        if self.topk == 1:\n            matched_row_inds, matched_col_inds = linear_sum_assignment(cost)\n        else:\n            topk_matched_row_inds = []\n            topk_matched_col_inds = []\n            for i in range(self.topk):\n                matched_row_inds, matched_col_inds = linear_sum_assignment(\n                    cost)\n                topk_matched_row_inds.append(matched_row_inds)\n                topk_matched_col_inds.append(matched_col_inds)\n                cost[matched_row_inds] = 1e10\n            matched_row_inds = np.concatenate(topk_matched_row_inds)\n            matched_col_inds = np.concatenate(topk_matched_col_inds)\n\n        matched_row_inds = torch.from_numpy(matched_row_inds).to(\n            bbox_pred.device)\n        matched_col_inds = torch.from_numpy(matched_col_inds).to(\n            bbox_pred.device)\n\n        # 4. assign backgrounds and foregrounds\n        # assign all indices to backgrounds first\n        assigned_gt_inds[:] = 0\n        # assign foregrounds based on matching results\n        assigned_gt_inds[matched_row_inds] = matched_col_inds + 1\n        assigned_labels[matched_row_inds] = gt_labels_tensor[matched_col_inds]\n        return AssignResult(num_gts, assigned_gt_inds, None, labels=assigned_labels), gt_masks_match\n"
  },
  {
    "path": "knet_vis/tracker/positional_encoding.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates.\n# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/position_encoding.py\n\"\"\"\nVarious positional encodings for the transformer.\n\"\"\"\nimport math\n\nimport torch\n\nfrom mmcv.cnn.bricks.transformer import POSITIONAL_ENCODING\nfrom mmcv.runner import BaseModule\n\n\n@POSITIONAL_ENCODING.register_module()\nclass PositionEmbeddingSine3D(BaseModule):\n    \"\"\"\n    This is a more standard version of the position embedding, very similar to the one\n    used by the Attention Is All You Need paper, generalized to work on videos\n    (b, t, c, h, w): sine/cosine encodings of the y and x positions are concatenated,\n    and a sine/cosine encoding of the temporal (t) position is added to them.\n    \"\"\"\n\n    def __init__(self, num_feats=64, temperature=10000, normalize=False, scale=None):\n        super().__init__()\n        self.num_pos_feats = num_feats\n        self.temperature = temperature\n        self.normalize = normalize\n        if scale is not None and normalize is False:\n            raise ValueError(\"normalize should be True if scale is passed\")\n        if scale is None:\n            scale = 2 * math.pi\n        self.scale = scale\n\n    def forward(self, x, mask=None):\n        # b, t, c, h, w\n        assert x.dim() == 5, f\"{x.shape} should be a 5-dimensional Tensor, got {x.dim()}-dimensional Tensor instead\"\n        if mask is None:\n            mask = torch.zeros((x.size(0), x.size(1), x.size(3), x.size(4)), device=x.device, dtype=torch.bool)\n        not_mask = ~mask\n        z_embed = not_mask.cumsum(1, dtype=torch.float32)\n        y_embed = not_mask.cumsum(2, dtype=torch.float32)\n        x_embed = not_mask.cumsum(3, dtype=torch.float32)\n        if self.normalize:\n            eps = 1e-6\n            z_embed = z_embed / (z_embed[:, -1:, :, :] + eps) * self.scale\n            y_embed = y_embed / (y_embed[:, :, -1:, :] + eps) * self.scale\n            x_embed = x_embed / (x_embed[:, :, :, -1:] + eps) * self.scale\n\n        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, 
device=x.device)\n        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)\n\n        dim_t_z = torch.arange((self.num_pos_feats * 2), dtype=torch.float32, device=x.device)\n        dim_t_z = self.temperature ** (2 * (dim_t_z // 2) / (self.num_pos_feats * 2))\n\n        pos_x = x_embed[:, :, :, :, None] / dim_t\n        pos_y = y_embed[:, :, :, :, None] / dim_t\n        pos_z = z_embed[:, :, :, :, None] / dim_t_z\n        pos_x = torch.stack((pos_x[:, :, :, :, 0::2].sin(), pos_x[:, :, :, :, 1::2].cos()), dim=5).flatten(4)\n        pos_y = torch.stack((pos_y[:, :, :, :, 0::2].sin(), pos_y[:, :, :, :, 1::2].cos()), dim=5).flatten(4)\n        pos_z = torch.stack((pos_z[:, :, :, :, 0::2].sin(), pos_z[:, :, :, :, 1::2].cos()), dim=5).flatten(4)\n        pos = (torch.cat((pos_y, pos_x), dim=4) + pos_z).permute(0, 1, 4, 2, 3)  # b, t, c, h, w\n        return pos\n"
  },
  {
    "path": "knet_vis/tracker/semantic_fpn_wrapper3D.py",
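    "content": "import torch\nimport torch.nn as nn\nfrom mmcv.cnn import ConvModule, normal_init\nfrom mmdet.models.builder import NECKS\nfrom mmcv.cnn.bricks.transformer import build_positional_encoding\nfrom mmdet.utils import get_root_logger\n\n\n@NECKS.register_module()\nclass SemanticFPNWrapper3D(nn.Module):\n    \"\"\"Implementation of Semantic FPN used in Panoptic FPN, extended to video\n    clips: an optional 3D (spatio-temporal) positional encoding can be added\n    to the features of level `cat_coors_level`.\n\n    Args:\n        in_channels (int): Number of channels of each input feature level.\n        feat_channels (int): Number of channels of the intermediate convs.\n        out_channels (int): Number of channels of the output features.\n        start_level (int): Index of the first input level to use.\n        end_level (int): Index of the last input level to use.\n        cat_coors (bool, optional): Whether to concatenate normalized (x, y)\n            coordinate maps to the features of level `cat_coors_level`.\n            Defaults to False.\n        positional_encoding (dict, optional): Config of a 3D positional\n            encoding added to the features of level `cat_coors_level`.\n            Defaults to None.\n        cat_coors_level (int, optional): Level that receives the positional\n            encoding and the coordinate maps. Defaults to 3.\n        fuse_by_cat (bool, optional): Whether to fuse the levels by\n            concatenation instead of summation. Defaults to False.\n        return_list (bool, optional): Whether to wrap the output in a list.\n            Defaults to False.\n        upsample_times (int, optional): Number of 2x upsampling steps applied\n            to the deepest level (`end_level`); shallower levels are upsampled\n            correspondingly fewer times. Defaults to 3.\n        with_pred (bool, optional): Whether to apply the final 1x1 output\n            conv. Defaults to True.\n        num_aux_convs (int, optional): Number of extra 1x1 output convs whose\n            outputs are returned together with the main output. Defaults to 0.\n        act_cfg (dict, optional): Activation config of the intermediate\n            convs. Defaults to `dict(type='ReLU', inplace=True)`.\n        out_act_cfg (dict, optional): Activation config of the output convs.\n            Defaults to `dict(type='ReLU')`.\n        conv_cfg (dict, optional): Config of the conv layers.\n            Defaults to None.\n        norm_cfg (dict, optional): Config of the norm layers.\n            Defaults to None.\n    \"\"\"\n\n    def __init__(self,\n                 in_channels,\n                 feat_channels,\n                 out_channels,\n                 start_level,\n                 end_level,\n                 cat_coors=False,\n                 positional_encoding=None,\n                 cat_coors_level=3,\n                 fuse_by_cat=False,\n                 return_list=False,\n                 upsample_times=3,\n                 with_pred=True,\n                 num_aux_convs=0,\n                 act_cfg=dict(type='ReLU', inplace=True),\n                 out_act_cfg=dict(type='ReLU'),\n                 conv_cfg=None,\n                 norm_cfg=None):\n        super().__init__()\n\n        self.in_channels = in_channels\n        self.feat_channels = feat_channels\n        self.start_level = start_level\n        self.end_level = end_level\n        assert start_level >= 0 and end_level >= start_level\n        self.out_channels = out_channels\n        self.conv_cfg = conv_cfg\n        self.norm_cfg = norm_cfg\n        self.act_cfg = act_cfg\n        self.cat_coors = 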
cat_coors\n        self.cat_coors_level = cat_coors_level\n        self.fuse_by_cat = fuse_by_cat\n        self.return_list = return_list\n        self.upsample_times = upsample_times\n        self.with_pred = with_pred\n        if positional_encoding is not None:\n            self.positional_encoding = build_positional_encoding(\n                positional_encoding)\n        else:\n            self.positional_encoding = None\n\n        self.convs_all_levels = nn.ModuleList()\n        for i in range(self.start_level, self.end_level + 1):\n            convs_per_level = nn.Sequential()\n            if i == 0:\n                if i == self.cat_coors_level and self.cat_coors:\n                    chn = self.in_channels + 2\n                else:\n                    chn = self.in_channels\n                if upsample_times == self.end_level - i:\n                    one_conv = ConvModule(\n                        chn,\n                        self.feat_channels,\n                        3,\n                        padding=1,\n                        conv_cfg=self.conv_cfg,\n                        norm_cfg=self.norm_cfg,\n                        act_cfg=self.act_cfg,\n                        inplace=False)\n                    convs_per_level.add_module('conv' + str(i), one_conv)\n                else:\n                    for i in range(self.end_level - upsample_times):\n                        one_conv = ConvModule(\n                            chn,\n                            self.feat_channels,\n                            3,\n                            padding=1,\n                            stride=2,\n                            conv_cfg=self.conv_cfg,\n                            norm_cfg=self.norm_cfg,\n                            act_cfg=self.act_cfg,\n                            inplace=False)\n                        convs_per_level.add_module('conv' + str(i), one_conv)\n                self.convs_all_levels.append(convs_per_level)\n                
continue\n\n            for j in range(i):\n                if j == 0:\n                    if i == self.cat_coors_level and self.cat_coors:\n                        chn = self.in_channels + 2\n                    else:\n                        chn = self.in_channels\n                    one_conv = ConvModule(\n                        chn,\n                        self.feat_channels,\n                        3,\n                        padding=1,\n                        conv_cfg=self.conv_cfg,\n                        norm_cfg=self.norm_cfg,\n                        act_cfg=self.act_cfg,\n                        inplace=False)\n                    convs_per_level.add_module('conv' + str(j), one_conv)\n                    if j < upsample_times - (self.end_level - i):\n                        one_upsample = nn.Upsample(\n                            scale_factor=2,\n                            mode='bilinear',\n                            align_corners=False)\n                        convs_per_level.add_module('upsample' + str(j),\n                                                   one_upsample)\n                    continue\n\n                one_conv = ConvModule(\n                    self.feat_channels,\n                    self.feat_channels,\n                    3,\n                    padding=1,\n                    conv_cfg=self.conv_cfg,\n                    norm_cfg=self.norm_cfg,\n                    act_cfg=self.act_cfg,\n                    inplace=False)\n                convs_per_level.add_module('conv' + str(j), one_conv)\n                if j < upsample_times - (self.end_level - i):\n                    one_upsample = nn.Upsample(\n                        scale_factor=2, mode='bilinear', align_corners=False)\n                    convs_per_level.add_module('upsample' + str(j),\n                                               one_upsample)\n\n            self.convs_all_levels.append(convs_per_level)\n\n        if fuse_by_cat:\n            in_channels = 
self.feat_channels * len(self.convs_all_levels)\n        else:\n            in_channels = self.feat_channels\n\n        if self.with_pred:\n            self.conv_pred = ConvModule(\n                in_channels,\n                self.out_channels,\n                1,\n                padding=0,\n                conv_cfg=self.conv_cfg,\n                act_cfg=out_act_cfg,\n                norm_cfg=self.norm_cfg)\n\n        self.num_aux_convs = num_aux_convs\n        self.aux_convs = nn.ModuleList()\n        for i in range(num_aux_convs):\n            self.aux_convs.append(\n                ConvModule(\n                    in_channels,\n                    self.out_channels,\n                    1,\n                    padding=0,\n                    conv_cfg=self.conv_cfg,\n                    act_cfg=out_act_cfg,\n                    norm_cfg=self.norm_cfg))\n\n    def init_weights(self):\n        logger = get_root_logger()\n        logger.info('Use normal initialization for semantic FPN')\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                normal_init(m, std=0.01)\n\n    def generate_coord(self, input_feat):\n        x_range = torch.linspace(\n            -1, 1, input_feat.shape[-1], device=input_feat.device)\n        y_range = torch.linspace(\n            -1, 1, input_feat.shape[-2], device=input_feat.device)\n        y, x = torch.meshgrid(y_range, x_range)\n        y = y.expand([input_feat.shape[0], 1, -1, -1])\n        x = x.expand([input_feat.shape[0], 1, -1, -1])\n        coord_feat = torch.cat([x, y], 1)\n        return coord_feat\n\n    def forward(self, inputs, num_imgs, num_frames):\n        mlvl_feats = []\n        for i in range(self.start_level, self.end_level + 1):\n            input_p = inputs[i]\n            if i == self.cat_coors_level:\n                if self.positional_encoding is not None:\n                    input_p = input_p.view(num_imgs, num_frames, *input_p.size()[1:])\n                    assert self.positional_encoding.__class__.__name__.endswith('3D')\n                    positional_encoding = self.positional_encoding(input_p)\n                    input_p = (input_p + positional_encoding).reshape(num_imgs * num_frames, *input_p.size()[2:])\n                if self.cat_coors:\n                    coord_feat = self.generate_coord(input_p)\n                    input_p = torch.cat([input_p, coord_feat], 1)\n\n            mlvl_feats.append(self.convs_all_levels[i](input_p))\n\n        if self.fuse_by_cat:\n            feature_add_all_level = torch.cat(mlvl_feats, dim=1)\n        else:\n            feature_add_all_level = sum(mlvl_feats)\n\n        if self.with_pred:\n            out = self.conv_pred(feature_add_all_level)\n        else:\n            out = feature_add_all_level\n\n        if self.num_aux_convs > 0:\n            outs = [out]\n            for conv in self.aux_convs:\n                outs.append(conv(feature_add_all_level))\n            return outs\n\n        if self.return_list:\n            return [out]\n        else:\n            return out\n"
  },
  {
    "path": "knet_vis/tracker/track.py",
    "content": "import copy\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom mmdet.models.builder import DETECTORS\nfrom mmdet.models.detectors import TwoStageDetector\nfrom mmdet.utils import get_root_logger\nfrom mmdet.models import build_head\n\nfrom knet_vis.det.utils import sem2ins_masks\n\n\n@DETECTORS.register_module()\nclass KNetTrack(TwoStageDetector):\n\n    def __init__(self,\n                 *args,\n                 num_thing_classes=80,\n                 num_stuff_classes=53,\n                 mask_assign_stride=4,\n                 thing_label_in_seg=0,\n                 direct_tracker=False,\n                 tracker_num=1,\n                 tracker=None,\n                 train_cfg=None,\n                 test_cfg=None,\n                 **kwargs):\n        self.roi_head = None # init roi_head with None\n        super().__init__(*args, **kwargs, train_cfg=train_cfg, test_cfg=test_cfg)\n        assert self.with_rpn, 'KNet does not support external proposals'\n        self.num_thing_classes = num_thing_classes\n        self.num_stuff_classes = num_stuff_classes\n        self.mask_assign_stride = mask_assign_stride\n        self.thing_label_in_seg = thing_label_in_seg\n        self.direct_tracker = direct_tracker\n        self.tracker_num = tracker_num\n        if tracker is not None:\n            rcnn_train_cfg = train_cfg.tracker if train_cfg is not None else None\n            tracker.update(train_cfg=rcnn_train_cfg)\n            tracker.update(test_cfg=test_cfg.tracker)\n            self.tracker = build_head(tracker)\n            if self.tracker_num > 1:\n                self.tracker_extra = nn.ModuleList(\n                    [build_head(tracker) for _ in range(tracker_num - 1)]\n                )\n        logger = get_root_logger()\n        logger.info(f'Model: \\n{self}')\n\n\n    def gt_transform(self, img_metas, gt_masks, gt_labels, gt_semantic_seg):\n        # gt_masks and gt_semantic_seg are not padded when 
forming batch\n        gt_masks_tensor = []\n        gt_sem_seg = []\n        gt_sem_cls = []\n        # batch_input_shape should be the same across images\n        pad_H, pad_W = img_metas[0]['batch_input_shape']\n        assign_H = pad_H // self.mask_assign_stride\n        assign_W = pad_W // self.mask_assign_stride\n\n        for i, gt_mask in enumerate(gt_masks):\n            mask_tensor = gt_mask.to_tensor(torch.float, gt_labels[0].device)\n            if gt_mask.width != pad_W or gt_mask.height != pad_H:\n                pad_wh = (0, pad_W - gt_mask.width, 0, pad_H - gt_mask.height)\n                mask_tensor = F.pad(mask_tensor, pad_wh, value=0)\n\n            if gt_semantic_seg is not None:\n                # gt_semantic_seg is padded with 255, and\n                # zero indicates the first class\n                sem_labels, sem_seg = sem2ins_masks(\n                    gt_semantic_seg[i],\n                    num_thing_classes=self.num_thing_classes)\n                if sem_seg.shape[0] == 0:\n                    gt_sem_seg.append(\n                        mask_tensor.new_zeros(\n                            (mask_tensor.size(0), assign_H, assign_W))\n                    )\n                else:\n                    gt_sem_seg.append(\n                        F.interpolate(\n                            sem_seg[None], (assign_H, assign_W),\n                            mode='bilinear',\n                            align_corners=False)[0]\n                    )\n                gt_sem_cls.append(sem_labels)\n\n            else:\n                gt_sem_seg = None\n                gt_sem_cls = None\n\n            if mask_tensor.shape[0] == 0:\n                gt_masks_tensor.append(\n                    mask_tensor.new_zeros(\n                        (mask_tensor.size(0), assign_H, assign_W))\n                )\n            else:\n                gt_masks_tensor.append(\n                    F.interpolate(\n                        mask_tensor[None], (assign_H, 
assign_W),\n                        mode='bilinear',\n                        align_corners=False)[0]\n                )\n        return gt_masks_tensor, gt_sem_seg, gt_sem_cls\n\n    def ref_gt_transform(self, ref_img_metas, ref_gt_masks, ref_gt_labels, ref_gt_semantic_seg=None):\n        # gt_masks and gt_semantic_seg are not padded when forming batch\n        ref_gt_masks_tensor = []\n        assert ref_gt_semantic_seg is None\n        ref_gt_sem_seg = None\n        ref_gt_sem_cls = None\n        # batch_input_shape should be the same across images\n        pad_H, pad_W = ref_img_metas[0]['batch_input_shape']\n        assign_H = pad_H // self.mask_assign_stride\n        assign_W = pad_W // self.mask_assign_stride\n\n        for bs_i, gt_mask_frame in enumerate(ref_gt_masks):\n            batch_cur_gt_masks_tensor = []\n            for i, gt_mask in enumerate(gt_mask_frame):\n                mask_tensor = gt_mask.to_tensor(torch.float, ref_gt_labels[bs_i].device)\n                if gt_mask.width != pad_W or gt_mask.height != pad_H:\n                    pad_wh = (0, pad_W - gt_mask.width, 0, pad_H - gt_mask.height)\n                    mask_tensor = F.pad(mask_tensor, pad_wh, value=0)\n\n                if mask_tensor.shape[0] == 0:\n                    batch_cur_gt_masks_tensor.append(\n                        mask_tensor.new_zeros(\n                            (mask_tensor.size(0), assign_H, assign_W))\n                    )\n                else:\n                    batch_cur_gt_masks_tensor.append(\n                        F.interpolate(\n                            mask_tensor[None], (assign_H, assign_W),\n                            mode='bilinear',\n                            align_corners=False)[0]\n                    )\n            ref_gt_masks_tensor.append(batch_cur_gt_masks_tensor)\n\n        return ref_gt_masks_tensor, ref_gt_sem_seg, ref_gt_sem_cls\n\n\n    def forward_train(self,\n                      img,\n                      img_metas,\n    
                  gt_bboxes=None,\n                      gt_labels=None,\n                      gt_bboxes_ignore=None,\n                      gt_masks=None,\n                      proposals=None,\n                      gt_semantic_seg=None,\n                      gt_instance_ids=None,\n                      # references\n                      ref_img=None,\n                      ref_img_metas=None,\n                      ref_gt_bboxes=None,\n                      ref_gt_labels=None,\n                      ref_gt_bboxes_ignore=None,\n                      ref_gt_masks=None,\n                      ref_gt_instance_ids=None,\n                      **kwargs):\n\n        super(TwoStageDetector, self).forward_train(img, img_metas)\n        assert proposals is None, 'KNet does not support external proposals'\n        assert gt_masks is not None\n\n        ref_gt_masks, ref_gt_sem_seg, ref_gt_sem_cls  = \\\n            self.ref_gt_transform(img_metas, ref_gt_masks, ref_gt_labels, ref_gt_semantic_seg=None)\n        bs, num_frame, _, h, w = ref_img.size()\n        x = self.extract_feat(ref_img.reshape(bs * num_frame, _, h, w))\n\n        losses = dict()\n\n        rpn_losses, proposal_feats, x_feats, mask_preds, cls_scores = \\\n            self.rpn_head.forward_train(x, img_metas, ref_img_metas, ref_gt_masks, ref_gt_labels,\n                                        ref_gt_instance_ids, ref_gt_sem_seg, ref_gt_sem_cls)\n        losses.update(rpn_losses)\n\n        if self.roi_head is not None:\n            roi_losses, features = self.roi_head.forward_train(\n                x_feats,\n                proposal_feats,\n                mask_preds,\n                cls_scores,\n                ref_img_metas,\n                ref_gt_masks,\n                ref_gt_labels,\n                gt_bboxes_ignore=ref_gt_bboxes_ignore,\n                gt_bboxes=ref_gt_bboxes,\n                gt_sem_seg=ref_gt_sem_seg,\n                gt_sem_cls=ref_gt_sem_cls,\n                
imgs_whwh=None)\n            losses.update(roi_losses)\n\n        if self.direct_tracker:\n            proposal_feats = self.rpn_head.init_kernels.weight.clone()\n            proposal_feats = proposal_feats[None].expand(bs, *proposal_feats.size())\n            if mask_preds.shape[0] == bs * num_frame:\n                mask_preds = mask_preds.reshape((bs, num_frame, *mask_preds.size()[1:]))\n                x_feats = x_feats.reshape((bs, num_frame, *x_feats.size()[1:]))\n            else:\n                assert mask_preds.size()[:2] == (bs, num_frame)\n                assert x_feats.size()[:2] == (bs, num_frame)\n\n            tracker_losses, features = self.tracker.forward_train(\n                x=x_feats,\n                ref_img_metas=ref_img_metas,\n                cls_scores=None,\n                masks=mask_preds,\n                obj_feats=proposal_feats,\n                ref_gt_masks=ref_gt_masks,\n                ref_gt_labels=ref_gt_labels,\n                ref_gt_instance_ids=ref_gt_instance_ids,\n            )\n            if self.tracker_num > 1:\n                for i in range(self.tracker_num - 1):\n                    _tracker_losses, features = self.tracker_extra[i].forward_train(\n                        x=features['x_feats'],\n                        ref_img_metas=ref_img_metas,\n                        cls_scores=None,\n                        masks=features['masks'],\n                        obj_feats=features['obj_feats'],\n                        ref_gt_masks=ref_gt_masks,\n                        ref_gt_labels=ref_gt_labels,\n                        ref_gt_instance_ids=ref_gt_instance_ids,\n                    )\n                    for key, value in _tracker_losses.items():\n                        tracker_losses[f'extra_m{i}_{key}'] = value\n        else:\n            tracker_losses, _ = self.tracker.forward_train(\n                x=features['x_feats'],\n                ref_img_metas=ref_img_metas,\n                
cls_scores=features['cls_scores'],\n                masks=features['masks'],\n                obj_feats=features['obj_feats'],\n                ref_gt_masks=ref_gt_masks,\n                ref_gt_labels=ref_gt_labels,\n                ref_gt_instance_ids=ref_gt_instance_ids,\n            )\n\n        losses.update(tracker_losses)\n        return losses\n\n    def forward_test(self, imgs, img_metas, **kwargs):\n        \"\"\"\n        Args:\n            imgs (List[Tensor]): the outer list indicates test-time\n                augmentations and inner Tensor should have a shape NxCxHxW,\n                which contains all images in the batch.\n            img_metas (List[List[dict]]): the outer list indicates test-time\n                augs (multiscale, flip, etc.) and the inner list indicates\n                images in a batch.\n        \"\"\"\n        for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]:\n            if not isinstance(var, list):\n                raise TypeError(f'{name} must be a list, but got {type(var)}')\n\n        num_augs = len(imgs)\n        if num_augs != len(img_metas):\n            raise ValueError(f'num of augmentations ({len(imgs)}) '\n                             f'!= num of image meta ({len(img_metas)})')\n\n        # NOTE the batched image size information may be useful, e.g.\n        # in DETR, this is needed for the construction of masks, which is\n        # then used for the transformer_head.\n        for img, img_meta in zip(imgs, img_metas):\n            batch_size = len(img_meta)\n            for img_id in range(batch_size):\n                img_meta[img_id]['batch_input_shape'] = tuple(img.size()[-2:])\n\n        if num_augs == 1:\n            # proposals (List[List[Tensor]]): the outer list indicates\n            # test-time augs (multiscale, flip, etc.) 
and the inner list\n            # indicates images in a batch.\n            # The Tensor should have a shape Px4, where P is the number of\n            # proposals.\n            if 'proposals' in kwargs:\n                kwargs['proposals'] = kwargs['proposals'][0]\n            kwargs['ref_img_metas'] = kwargs['ref_img_metas'][0]\n            kwargs['ref_img'] = kwargs['ref_img'][0]\n            return self.simple_test(imgs[0], img_metas[0], **kwargs)\n        else:\n            assert imgs[0].size(0) == 1, 'aug test does not support ' \\\n                                         'inference with batch size ' \\\n                                         f'{imgs[0].size(0)}'\n            # TODO: support test augmentation for predefined proposals\n            assert 'proposals' not in kwargs\n            return self.aug_test(imgs, img_metas, **kwargs)\n\n    def simple_test(self, imgs, img_metas, **kwargs):\n        ref_img = kwargs['ref_img']\n        ref_img_metas = kwargs['ref_img_metas']\n        # Step 1 extract features and get masks\n        bs, num_frame, _, h, w = ref_img.size()\n        x = self.extract_feat(ref_img.reshape(bs * num_frame, _, h, w))\n\n        proposal_feats, x_feats, mask_preds, cls_scores, seg_preds = \\\n            self.rpn_head.simple_test_rpn(x, img_metas, ref_img_metas)\n\n        if self.roi_head is not None:\n            segm_results_single_frame, features = self.roi_head.simple_test(\n                x_feats,\n                proposal_feats,\n                mask_preds,\n                cls_scores,\n                img_metas,\n                ref_img_metas,\n                imgs_whwh=None,\n                rescale=True\n            )\n\n        if self.direct_tracker:\n            proposal_feats = self.rpn_head.init_kernels.weight.clone()\n            proposal_feats = proposal_feats[None].expand(bs, *proposal_feats.size())\n            if mask_preds.shape[0] == bs * num_frame:\n                mask_preds = mask_preds.reshape((bs, 
num_frame, *mask_preds.size()[1:]))\n                x_feats = x_feats.reshape((bs, num_frame, *x_feats.size()[1:]))\n            else:\n                assert mask_preds.size()[:2] == (bs, num_frame)\n                assert x_feats.size()[:2] == (bs, num_frame)\n            segm_results, features = self.tracker.simple_test(\n                x=x_feats,\n                img_metas=img_metas,\n                ref_img_metas=ref_img_metas,\n                cls_scores=None,\n                masks=mask_preds,\n                obj_feats=proposal_feats,\n            )\n            if self.tracker_num > 1:\n                for i in range(self.tracker_num - 1):\n                    segm_results, features = self.tracker_extra[i].simple_test(\n                        x=features['x_feats'],\n                        img_metas=img_metas,\n                        ref_img_metas=ref_img_metas,\n                        cls_scores=None,\n                        masks=features['masks'],\n                        obj_feats=features['obj_feats'],\n                    )\n        else:\n            segm_results, _ = self.tracker.simple_test(\n                x=features['x_feats'],\n                img_metas=img_metas,\n                ref_img_metas=ref_img_metas,\n                cls_scores=features['cls_scores'],\n                masks=features['masks'],\n                obj_feats=features['obj_feats'],\n            )\n\n        return segm_results\n\n    def forward_dummy(self, img):\n        \"\"\"Used for computing network flops.\n\n        See `mmdetection/tools/get_flops.py`\n        \"\"\"\n        # backbone\n        x = self.extract_feat(img)\n        # rpn\n        num_imgs = len(img)\n        dummy_img_metas = [\n            dict(img_shape=(800, 1333, 3)) for _ in range(num_imgs)\n        ]\n        rpn_results = self.rpn_head.simple_test_rpn(x, dummy_img_metas)\n        (proposal_feats, x_feats, mask_preds, cls_scores,\n         seg_preds) = rpn_results\n        # roi_head\n      
  roi_outs = self.roi_head.forward_dummy(x_feats, proposal_feats, dummy_img_metas)\n        return roi_outs\n\n    def init_weights(self):\n        super().init_weights()\n        if self.init_cfg is not None and self.init_cfg['type'] == 'Pretrained':\n            assert self.tracker.init_cfg is None\n            self.tracker.init_cfg = copy.deepcopy(self.init_cfg)\n            self.tracker.init_cfg['prefix']='roi_head'\n            self.tracker.init_weights()\n            if self.tracker_num > 1:\n                for _ in range(self.tracker_num - 1):\n                    assert self.tracker_extra[_].init_cfg is None\n                    self.tracker_extra[_].init_cfg = copy.deepcopy(self.init_cfg)\n                    self.tracker_extra[_].init_cfg['prefix'] = 'roi_head'\n                    self.tracker_extra[_].init_weights()\n"
  },
  {
    "path": "mmtrack/datasets/coco_video_dataset.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport random\n\nimport numpy as np\nfrom mmcv.utils import print_log\nfrom mmdet.datasets import DATASETS, CocoDataset\nfrom terminaltables import AsciiTable\n\nfrom mmdet.utils import get_root_logger\nfrom .parsers import CocoVID\n\n\n@DATASETS.register_module()\nclass CocoVideoDataset(CocoDataset):\n    \"\"\"Base coco video dataset for VID, MOT and SOT tasks.\n\n    Args:\n        load_as_video (bool): If True, using COCOVID class to load dataset,\n            otherwise, using COCO class. Default: True.\n        key_img_sampler (dict): Configuration of sampling key images.\n        ref_img_sampler (dict): Configuration of sampling ref images.\n        test_load_ann (bool): If True, loading annotations during testing,\n            otherwise, not loading. Default: False.\n        load_all_frames (bool): If True, all frames of a video are loaded as\n            one sample, which requires `ref_img_sampler` to be None.\n            Default: False.\n    \"\"\"\n\n    CLASSES = None\n\n    def __init__(self,\n                 load_as_video=True,\n                 key_img_sampler=dict(interval=1),\n                 ref_img_sampler=dict(\n                     frame_range=10,\n                     stride=1,\n                     num_ref_imgs=1,\n                     filter_key_img=True,\n                     method='uniform',\n                     return_key_img=True),\n                 test_load_ann=False,\n                 load_all_frames=False,\n                 *args,\n                 **kwargs):\n        self.load_as_video = load_as_video\n        self.key_img_sampler = key_img_sampler\n        self.ref_img_sampler = ref_img_sampler\n        self.test_load_ann = test_load_ann\n        self.load_all_frames = load_all_frames\n        assert not (self.load_all_frames and ref_img_sampler is not None), \"load_all_frames=True requires ref_img_sampler to be None\"\n        super().__init__(*args, **kwargs)\n        self.logger = get_root_logger()\n\n    def load_annotations(self, ann_file):\n        \"\"\"Load annotations from COCO/COCOVID style annotation file.\n\n        Args:\n            ann_file 
(str): Path of annotation file.\n\n        Returns:\n            list[dict]: Annotation information from COCO/COCOVID api.\n        \"\"\"\n        if not self.load_as_video:\n            data_infos = super().load_annotations(ann_file)\n        else:\n            data_infos = self.load_video_anns(ann_file)\n        return data_infos\n\n    def load_video_anns(self, ann_file):\n        \"\"\"Load annotations from COCOVID style annotation file.\n\n        Args:\n            ann_file (str): Path of annotation file.\n\n        Returns:\n            list[dict]: Annotation information from COCOVID api.\n        \"\"\"\n        self.coco = CocoVID(ann_file)\n        self.cat_ids = self.coco.get_cat_ids(cat_names=self.CLASSES)\n        self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}\n\n        data_infos = []\n        self.vid_ids = self.coco.get_vid_ids()\n        self.img_ids = [] if not self.load_all_frames else None\n        for vid_id in self.vid_ids:\n            img_ids = self.coco.get_img_ids_from_vid(vid_id)\n            if self.key_img_sampler is not None:\n                img_ids = self.key_img_sampling(img_ids,\n                                                **self.key_img_sampler)\n            if self.load_all_frames:\n                info = self.coco.load_imgs(img_ids)\n                info = [info[0], *info]\n                for item in info:\n                    item['filename'] = item['file_name']\n                data_infos.append(info)\n            else:\n                self.img_ids.extend(img_ids)\n                for img_id in img_ids:\n                    info = self.coco.load_imgs([img_id])[0]\n                    info['filename'] = info['file_name']\n                    data_infos.append(info)\n        return data_infos\n\n    def key_img_sampling(self, img_ids, interval=1):\n        \"\"\"Sampling key images.\"\"\"\n        return img_ids[::interval]\n\n    def ref_img_sampling(self,\n                         img_info,\n     
                    frame_range,\n                         stride=1,\n                         num_ref_imgs=1,\n                         filter_key_img=True,\n                         method='uniform',\n                         return_key_img=True):\n        \"\"\"Sampling reference frames in the same video for key frame.\n\n        Args:\n            img_info (dict): The information of key frame.\n            frame_range (List(int) | int): The sampling range of reference\n                frames in the same video for key frame.\n            stride (int): The sampling frame stride when sampling reference\n                images. Default: 1.\n            num_ref_imgs (int): The number of sampled reference images.\n                Default: 1.\n            filter_key_img (bool): If False, the key image will be in the\n                sampling reference candidates, otherwise, it is excluded.\n                Default: True.\n            method (str): The sampling method. Options are 'uniform',\n                'bilateral_uniform', 'test_with_adaptive_stride',\n                'test_with_fix_stride'. 'uniform' denotes reference images are\n                randomly sampled from the nearby frames of key frame.\n                'bilateral_uniform' denotes reference images are randomly\n                sampled from the two sides of the nearby frames of key frame.\n                'test_with_adaptive_stride' is only used in testing, and\n                denotes the sampling frame stride is equal to (video length /\n                the number of reference images). test_with_fix_stride is only\n                used in testing with sampling frame stride equal to\n                `stride`. Default: 'uniform'.\n            return_key_img (bool): If True, the information of key frame is\n                returned, otherwise, not returned. 
Default: True.\n\n        Returns:\n            list(dict): `img_info` and the reference images information or\n            only the reference images information.\n        \"\"\"\n        assert isinstance(img_info, dict)\n        if isinstance(frame_range, int):\n            assert frame_range >= 0, 'frame_range cannot be a negative value.'\n            frame_range = [-frame_range, frame_range]\n        elif isinstance(frame_range, list):\n            assert len(frame_range) == 2, 'The length must be 2.'\n            assert frame_range[0] <= 0 and frame_range[1] >= 0\n            for i in frame_range:\n                assert isinstance(i, int), 'Each element must be int.'\n        else:\n            raise TypeError('The type of frame_range must be int or list.')\n\n        if 'test' in method and \\\n                (frame_range[1] - frame_range[0]) != num_ref_imgs:\n            print_log(\n                'Warning: '\n                \"frame_range[1] - frame_range[0] isn't equal to num_ref_imgs. \"\n                'Set num_ref_imgs to frame_range[1] - frame_range[0].',\n                logger=self.logger)\n            self.ref_img_sampler[\n                'num_ref_imgs'] = frame_range[1] - frame_range[0]\n\n        if (not self.load_as_video) or img_info.get('frame_id', -1) < 0 \\\n                or (frame_range[0] == 0 and frame_range[1] == 0):\n            ref_img_infos = []\n            for i in range(num_ref_imgs):\n                ref_img_infos.append(img_info.copy())\n        else:\n            vid_id, img_id, frame_id = img_info['video_id'], img_info[\n                'id'], img_info['frame_id']\n            img_ids = self.coco.get_img_ids_from_vid(vid_id)\n            left = max(0, frame_id + frame_range[0])\n            right = min(frame_id + frame_range[1], len(img_ids) - 1)\n\n            ref_img_ids = []\n            if method == 'uniform':\n                valid_ids = img_ids[left:right + 1]\n                if filter_key_img and img_id in 
valid_ids:\n                    valid_ids.remove(img_id)\n\n                if num_ref_imgs != len(valid_ids):\n                    return None\n\n\n                num_samples = min(num_ref_imgs, len(valid_ids))\n                ref_img_ids.extend(random.sample(valid_ids, num_samples))\n            elif method == 'bilateral_uniform':\n                assert num_ref_imgs % 2 == 0, \\\n                    'only support load even number of ref_imgs.'\n                for mode in ['left', 'right']:\n                    if mode == 'left':\n                        valid_ids = img_ids[left:frame_id + 1]\n                    else:\n                        valid_ids = img_ids[frame_id:right + 1]\n                    if filter_key_img and img_id in valid_ids:\n                        valid_ids.remove(img_id)\n                    num_samples = min(num_ref_imgs // 2, len(valid_ids))\n                    sampled_inds = random.sample(valid_ids, num_samples)\n                    ref_img_ids.extend(sampled_inds)\n            elif method == 'test_with_adaptive_stride':\n                if frame_id == 0:\n                    stride = float(len(img_ids) - 1) / (num_ref_imgs - 1)\n                    for i in range(num_ref_imgs):\n                        ref_id = round(i * stride)\n                        ref_img_ids.append(img_ids[ref_id])\n            elif method == 'test_with_fix_stride':\n                if frame_id == 0:\n                    for i in range(frame_range[0], 1):\n                        ref_img_ids.append(img_ids[0])\n                    for i in range(1, frame_range[1] + 1):\n                        ref_id = min(round(i * stride), len(img_ids) - 1)\n                        ref_img_ids.append(img_ids[ref_id])\n                elif frame_id % stride == 0:\n                    ref_id = min(\n                        round(frame_id + frame_range[1] * stride),\n                        len(img_ids) - 1)\n                    ref_img_ids.append(img_ids[ref_id])\n           
     img_info['num_left_ref_imgs'] = abs(frame_range[0]) \\\n                    if isinstance(frame_range, list) else frame_range\n                img_info['frame_stride'] = stride\n            else:\n                raise NotImplementedError\n\n            ref_img_infos = []\n            for ref_img_id in ref_img_ids:\n                ref_img_info = self.coco.load_imgs([ref_img_id])[0]\n                ref_img_info['filename'] = ref_img_info['file_name']\n                ref_img_infos.append(ref_img_info)\n            ref_img_infos = sorted(ref_img_infos, key=lambda i: i['frame_id'])\n\n        if return_key_img:\n            return [img_info, *ref_img_infos]\n        else:\n            return ref_img_infos\n\n    def get_ann_info(self, img_info):\n        \"\"\"Get COCO annotations by the information of image.\n\n        Args:\n            img_info (dict): Information of image.\n\n        Returns:\n            dict: Annotation information of `img_info`.\n        \"\"\"\n        img_id = img_info['id']\n        ann_ids = self.coco.get_ann_ids(img_ids=[img_id], cat_ids=self.cat_ids)\n        ann_info = self.coco.load_anns(ann_ids)\n        return self._parse_ann_info(img_info, ann_info)\n\n    def prepare_results(self, img_info):\n        \"\"\"Prepare results for image (e.g. 
the annotation information, ...).\"\"\"\n        results = dict(img_info=img_info)\n        if not self.test_mode or self.test_load_ann:\n            results['ann_info'] = self.get_ann_info(img_info)\n        if self.proposals is not None:\n            idx = self.img_ids.index(img_info['id'])\n            results['proposals'] = self.proposals[idx]\n\n        super().pre_pipeline(results)\n        results['is_video_data'] = self.load_as_video\n        return results\n\n    def prepare_data(self, idx):\n        \"\"\"Get data and annotations after pipeline.\n\n        Args:\n            idx (int): Index of data.\n\n        Returns:\n            dict: Data and annotations after pipeline with new keys introduced\n            by pipeline.\n        \"\"\"\n        img_info = self.data_infos[idx]\n        if self.ref_img_sampler is not None:\n            img_infos = self.ref_img_sampling(img_info, **self.ref_img_sampler)\n            if img_infos is None:\n                return None\n            results = [\n                self.prepare_results(img_info) for img_info in img_infos\n            ]\n        elif self.load_all_frames:\n            results = [\n                self.prepare_results(_img_info) for _img_info in img_info\n            ]\n        else:\n            results = self.prepare_results(img_info)\n        return self.pipeline(results)\n\n    def prepare_train_img(self, idx):\n        \"\"\"Get training data and annotations after pipeline.\n\n        Args:\n            idx (int): Index of data.\n\n        Returns:\n            dict: Training data and annotations after pipeline with new keys\n            introduced by pipeline.\n        \"\"\"\n        return self.prepare_data(idx)\n\n    def prepare_test_img(self, idx):\n        \"\"\"Get testing data after pipeline.\n\n        Args:\n            idx (int): Index of data.\n\n        Returns:\n            dict: Testing data after pipeline with new keys introduced by\n            pipeline.\n        \"\"\"\n    
    return self.prepare_data(idx)\n\n    def _parse_ann_info(self, img_info, ann_info):\n        \"\"\"Parse bbox and mask annotations.\n\n        Args:\n            img_info (dict): Information of image.\n            ann_info (list[dict]): Annotation information of image.\n\n        Returns:\n            dict: A dict containing the following keys: bboxes, bboxes_ignore,\n            labels, instance_ids, masks, seg_map. \"masks\" are raw\n            annotations and not decoded into binary masks.\n        \"\"\"\n        gt_bboxes = []\n        gt_labels = []\n        gt_bboxes_ignore = []\n        gt_masks = []\n        gt_instance_ids = []\n\n        for i, ann in enumerate(ann_info):\n            if ann.get('ignore', False):\n                continue\n            x1, y1, w, h = ann['bbox']\n            inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0))\n            inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0))\n            if inter_w * inter_h == 0:\n                continue\n            if ann['area'] <= 0 or w < 1 or h < 1:\n                continue\n            if ann['category_id'] not in self.cat_ids:\n                continue\n            bbox = [x1, y1, x1 + w, y1 + h]\n            if ann.get('iscrowd', False):\n                gt_bboxes_ignore.append(bbox)\n            else:\n                gt_bboxes.append(bbox)\n                gt_labels.append(self.cat2label[ann['category_id']])\n                if 'segmentation' in ann:\n                    gt_masks.append(ann['segmentation'])\n                if 'instance_id' in ann:\n                    gt_instance_ids.append(ann['instance_id'])\n\n        if gt_bboxes:\n            gt_bboxes = np.array(gt_bboxes, dtype=np.float32)\n            gt_labels = np.array(gt_labels, dtype=np.int64)\n        else:\n            gt_bboxes = np.zeros((0, 4), dtype=np.float32)\n            gt_labels = np.array([], dtype=np.int64)\n\n        if gt_bboxes_ignore:\n            gt_bboxes_ignore = 
np.array(gt_bboxes_ignore, dtype=np.float32)\n        else:\n            gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)\n\n        seg_map = img_info['filename'].replace('jpg', 'png')\n\n        ann = dict(\n            bboxes=gt_bboxes,\n            labels=gt_labels,\n            bboxes_ignore=gt_bboxes_ignore,\n            masks=gt_masks,\n            seg_map=seg_map)\n\n        if self.load_as_video:\n            ann['instance_ids'] = np.array(gt_instance_ids).astype(np.int64)\n        else:\n            ann['instance_ids'] = np.arange(len(gt_labels))\n\n        return ann\n\n    def evaluate(self,\n                 results,\n                 metric=['bbox', 'track'],\n                 logger=None,\n                 bbox_kwargs=dict(\n                     classwise=False,\n                     proposal_nums=(100, 300, 1000),\n                     iou_thrs=None,\n                     metric_items=None),\n                 track_kwargs=dict(\n                     iou_thr=0.5,\n                     ignore_iof_thr=0.5,\n                     ignore_by_classes=False,\n                     nproc=4)):\n        \"\"\"Evaluation in COCO protocol and CLEAR MOT metric (e.g. MOTA, IDF1).\n\n        Args:\n            results (dict): Testing results of the dataset.\n            metric (str | list[str]): Metrics to be evaluated. Options are\n                'bbox', 'segm', 'track'.\n            logger (logging.Logger | str | None): Logger used for printing\n                related information during evaluation. 
Default: None.\n            bbox_kwargs (dict): Configuration for COCO style evaluation.\n            track_kwargs (dict): Configuration for CLEAR MOT evaluation.\n\n        Returns:\n            dict[str, float]: COCO style and CLEAR MOT evaluation metric.\n        \"\"\"\n        if isinstance(metric, list):\n            metrics = metric\n        elif isinstance(metric, str):\n            metrics = [metric]\n        else:\n            raise TypeError('metric must be a list or a str.')\n        allowed_metrics = ['bbox', 'segm', 'track']\n        for metric in metrics:\n            if metric not in allowed_metrics:\n                raise KeyError(f'metric {metric} is not supported.')\n\n        eval_results = dict()\n        if 'track' in metrics:\n            assert len(self.data_infos) == len(results['track_bboxes'])\n            inds = [\n                i for i, _ in enumerate(self.data_infos) if _['frame_id'] == 0\n            ]\n            num_vids = len(inds)\n            inds.append(len(self.data_infos))\n\n            track_bboxes = [\n                results['track_bboxes'][inds[i]:inds[i + 1]]\n                for i in range(num_vids)\n            ]\n            ann_infos = [self.get_ann_info(_) for _ in self.data_infos]\n            ann_infos = [\n                ann_infos[inds[i]:inds[i + 1]] for i in range(num_vids)\n            ]\n            raise NotImplementedError(\"eval_mot is not implemented yet.\")\n            # track_eval_results = eval_mot(\n            #     results=track_bboxes,\n            #     annotations=ann_infos,\n            #     logger=logger,\n            #     classes=self.CLASSES,\n            #     **track_kwargs)\n            # eval_results.update(track_eval_results)\n\n        # evaluate for detectors without tracker\n        super_metrics = ['bbox', 'segm']\n        super_metrics = [_ for _ in metrics if _ in super_metrics]\n        if super_metrics:\n            if isinstance(results, dict):\n                if 'bbox' 
in super_metrics and 'segm' in super_metrics:\n                    super_results = []\n                    for bbox, mask in zip(results['det_bboxes'],\n                                          results['det_masks']):\n                        super_results.append((bbox, mask))\n                else:\n                    super_results = results['det_bboxes']\n            elif isinstance(results, list):\n                super_results = results\n            else:\n                raise TypeError('Results must be a dict or a list.')\n            super_eval_results = super().evaluate(\n                results=super_results,\n                metric=super_metrics,\n                logger=logger,\n                **bbox_kwargs)\n            eval_results.update(super_eval_results)\n\n        return eval_results\n\n    def __repr__(self):\n        \"\"\"Print the number of instance number suit for video dataset.\"\"\"\n        dataset_type = 'Test' if self.test_mode else 'Train'\n        result = (f'\\n{self.__class__.__name__} {dataset_type} dataset '\n                  f'with number of images {len(self)}, '\n                  f'and instance counts: \\n')\n        if self.CLASSES is None:\n            result += 'Category names are not provided. 
\\n'\n            return result\n        instance_count = np.zeros(len(self.CLASSES) + 1).astype(int)\n        # count the instance number in each image\n        for idx in range(len(self)):\n            img_info = self.data_infos[idx]\n            label = self.get_ann_info(img_info)['labels']\n            unique, counts = np.unique(label, return_counts=True)\n            if len(unique) > 0:\n                # add the occurrence number to each class\n                instance_count[unique] += counts\n            else:\n                # background is the last index\n                instance_count[-1] += 1\n        # create a table with category count\n        table_data = [['category', 'count'] * 5]\n        row_data = []\n        for cls, count in enumerate(instance_count):\n            if cls < len(self.CLASSES):\n                row_data += [f'{cls} [{self.CLASSES[cls]}]', f'{count}']\n            else:\n                # add the background number\n                row_data += ['-1 background', f'{count}']\n            if len(row_data) == 10:\n                table_data.append(row_data)\n                row_data = []\n        if len(row_data) >= 2:\n            if row_data[-1] == '0':\n                row_data = row_data[:-2]\n            if len(row_data) >= 2:\n                table_data.append([])\n                table_data.append(row_data)\n\n        table = AsciiTable(table_data)\n        result += table.table\n        return result\n"
  },
  {
    "path": "mmtrack/datasets/parsers/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .coco_video_parser import CocoVID\n\n__all__ = ['CocoVID']\n"
  },
  {
    "path": "mmtrack/datasets/parsers/coco_video_parser.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom collections import defaultdict\n\nimport numpy as np\nfrom mmdet.datasets.api_wrappers import COCO\nfrom pycocotools.coco import _isArrayLike\n\n\nclass CocoVID(COCO):\n    \"\"\"Inherit official COCO class in order to parse the annotations of bbox-\n    related video tasks.\n    Args:\n        annotation_file (str): location of annotation file. Defaults to None.\n        load_img_as_vid (bool): If True, convert image data to video data,\n            which means each image is converted to a video. Defaults to False.\n    \"\"\"\n\n    def __init__(self, annotation_file=None, load_img_as_vid=False):\n        assert annotation_file, 'Annotation file must be provided.'\n        self.load_img_as_vid = load_img_as_vid\n        super(CocoVID, self).__init__(annotation_file=annotation_file)\n\n    def convert_img_to_vid(self, dataset):\n        \"\"\"Convert image data to video data.\"\"\"\n        if 'images' in self.dataset:\n            videos = []\n            for i, img in enumerate(self.dataset['images']):\n                videos.append(dict(id=img['id'], name=img['file_name']))\n                img['video_id'] = img['id']\n                img['frame_id'] = 0\n            dataset['videos'] = videos\n\n        if 'annotations' in self.dataset:\n            for i, ann in enumerate(self.dataset['annotations']):\n                ann['video_id'] = ann['image_id']\n                ann['instance_id'] = ann['id']\n        return dataset\n\n    def createIndex(self, use_ext=False):\n        \"\"\"Create index.\"\"\"\n        print('creating index...')\n        anns, cats, imgs, vids = {}, {}, {}, {}\n        (imgToAnns, catToImgs, vidToImgs, vidToInstances,\n         instancesToImgs) = defaultdict(list), defaultdict(list), defaultdict(\n            list), defaultdict(list), defaultdict(list)\n\n        if 'videos' not in self.dataset and self.load_img_as_vid:\n            self.dataset = 
self.convert_img_to_vid(self.dataset)\n\n        if 'videos' in self.dataset:\n            for video in self.dataset['videos']:\n                vids[video['id']] = video\n\n        if 'annotations' in self.dataset:\n            for ann in self.dataset['annotations']:\n                imgToAnns[ann['image_id']].append(ann)\n                anns[ann['id']] = ann\n                if 'instance_id' in ann:\n                    instancesToImgs[ann['instance_id']].append(ann['image_id'])\n                    if 'video_id' in ann and \\\n                            ann['instance_id'] not in \\\n                            vidToInstances[ann['video_id']]:\n                        vidToInstances[ann['video_id']].append(\n                            ann['instance_id'])\n\n        if 'images' in self.dataset:\n            for img in self.dataset['images']:\n                vidToImgs[img['video_id']].append(img)\n                imgs[img['id']] = img\n\n        if 'categories' in self.dataset:\n            for cat in self.dataset['categories']:\n                cats[cat['id']] = cat\n\n        if 'annotations' in self.dataset and 'categories' in self.dataset:\n            for ann in self.dataset['annotations']:\n                catToImgs[ann['category_id']].append(ann['image_id'])\n\n        print('index created!')\n\n        self.anns = anns\n        self.imgToAnns = imgToAnns\n        self.catToImgs = catToImgs\n        self.imgs = imgs\n        self.cats = cats\n        self.videos = vids\n        self.vidToImgs = vidToImgs\n        self.vidToInstances = vidToInstances\n        self.instancesToImgs = instancesToImgs\n\n    def get_vid_ids(self, vidIds=[]):\n        \"\"\"Get video ids that satisfy given filter conditions.\n        Default return all video ids.\n        Args:\n            vidIds (list[int]): The given video ids. 
Defaults to [].\n        Returns:\n            list[int]: Video ids.\n        \"\"\"\n        vidIds = vidIds if _isArrayLike(vidIds) else [vidIds]\n\n        if len(vidIds) == 0:\n            ids = self.videos.keys()\n        else:\n            ids = set(vidIds)\n\n        return list(ids)\n\n    def get_img_ids_from_vid(self, vidId):\n        \"\"\"Get image ids from given video id.\n        Args:\n            vidId (int): The given video id.\n        Returns:\n            list[int]: Image ids of given video id.\n        \"\"\"\n        img_infos = self.vidToImgs[vidId]\n        ids = list(np.zeros([len(img_infos)], dtype=np.int64))\n        for img_info in img_infos:\n            ids[img_info['frame_id']] = img_info['id']\n        return ids\n\n    def get_ins_ids_from_vid(self, vidId):\n        \"\"\"Get instance ids from given video id.\n        Args:\n            vidId (int): The given video id.\n        Returns:\n            list[int]: Instance ids of given video id.\n        \"\"\"\n        return self.vidToInstances[vidId]\n\n    def get_img_ids_from_ins_id(self, insId):\n        \"\"\"Get image ids from given instance id.\n        Args:\n            insId (int): The given instance id.\n        Returns:\n            list[int]: Image ids of given instance id.\n        \"\"\"\n        return self.instancesToImgs[insId]\n\n    def load_vids(self, ids=[]):\n        \"\"\"Get video information of given video ids.\n        Args:\n            ids (list[int] | int): The given video id(s). Defaults to [],\n                in which case an empty list is returned.\n        Returns:\n            list[dict]: List of video information.\n        \"\"\"\n        if _isArrayLike(ids):\n            return [self.videos[id] for id in ids]\n        elif isinstance(ids, (int, np.integer)):\n            return [self.videos[ids]]\n"
  },
  {
    "path": "mmtrack/datasets/youtube_vis_dataset.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport os.path\nimport os.path as osp\nimport tempfile\nimport zipfile\n\nimport mmcv\nimport numpy as np\nfrom mmcv.utils import print_log\nfrom mmdet.datasets import DATASETS\n\nfrom .coco_video_dataset import CocoVideoDataset\n\n\ndef results2outs(bbox_results=None,\n                 mask_results=None,\n                 mask_shape=None,\n                 **kwargs):\n    \"\"\"Restore the results (list of results of each category) into the results\n    of the model forward.\n    Args:\n        bbox_results (list[np.ndarray]): Each list denotes bboxes of one\n            category.\n        mask_results (list[list[np.ndarray]]): Each outer list denotes masks of\n            one category. Each inner list denotes one mask belonging to\n            the category. Each mask has shape (h, w).\n        mask_shape (tuple[int]): The shape (h, w) of mask.\n    Returns:\n        tuple: tracking results of each class. It may contain keys as belows:\n        - bboxes (np.ndarray): shape (n, 5)\n        - labels (np.ndarray): shape (n, )\n        - masks (np.ndarray): shape (n, h, w)\n        - ids (np.ndarray): shape (n, )\n    \"\"\"\n    outputs = dict()\n\n    if bbox_results is not None:\n        labels = []\n        for i, bbox in enumerate(bbox_results):\n            labels.extend([i] * bbox.shape[0])\n        labels = np.array(labels, dtype=np.int64)\n        outputs['labels'] = labels\n\n        bboxes = np.concatenate(bbox_results, axis=0).astype(np.float32)\n        if bboxes.shape[1] == 5:\n            outputs['bboxes'] = bboxes\n        elif bboxes.shape[1] == 6:\n            ids = bboxes[:, 0].astype(np.int64)\n            bboxes = bboxes[:, 1:]\n            outputs['bboxes'] = bboxes\n            outputs['ids'] = ids\n        else:\n            raise NotImplementedError(\n                f'Not supported bbox shape: (N, {bboxes.shape[1]})')\n\n    if mask_results is not None:\n        assert 
mask_shape is not None\n        mask_height, mask_width = mask_shape\n        mask_results = mmcv.concat_list(mask_results)\n        if len(mask_results) == 0:\n            masks = np.zeros((0, mask_height, mask_width)).astype(bool)\n        else:\n            masks = np.stack(mask_results, axis=0)\n        outputs['masks'] = masks\n\n    return outputs\n\n\n@DATASETS.register_module()\nclass YouTubeVISDataset(CocoVideoDataset):\n    \"\"\"YouTube VIS dataset for video instance segmentation.\"\"\"\n\n    CLASSES_2019_version = ('person', 'giant_panda', 'lizard', 'parrot', 'skateboard',\n                            'sedan', 'ape', 'dog', 'snake', 'monkey',\n                            'hand', 'rabbit', 'duck', 'cat', 'cow',\n                            'fish', 'train', 'horse', 'turtle', 'bear',\n                            'motorbike', 'giraffe', 'leopard', 'fox', 'deer',\n                            'owl', 'surfboard', 'airplane', 'truck', 'zebra',\n                            'tiger', 'elephant', 'snowboard', 'boat', 'shark',\n                            'mouse', 'frog', 'eagle', 'earless_seal', 'tennis_racket')\n\n    CLASSES_2021_version = ('airplane', 'bear', 'bird', 'boat', 'car', 'cat',\n                            'cow', 'deer', 'dog', 'duck', 'earless_seal',\n                            'elephant', 'fish', 'flying_disc', 'fox', 'frog',\n                            'giant_panda', 'giraffe', 'horse', 'leopard',\n                            'lizard', 'monkey', 'motorbike', 'mouse', 'parrot',\n                            'person', 'rabbit', 'shark', 'skateboard', 'snake',\n                            'snowboard', 'squirrel', 'surfboard',\n                            'tennis_racket', 'tiger', 'train', 'truck',\n                            'turtle', 'whale', 'zebra')\n\n    def __init__(self, dataset_version, *args, **kwargs):\n        self.set_dataset_classes(dataset_version)\n        super().__init__(*args, **kwargs)\n\n    @classmethod\n    def 
set_dataset_classes(cls, dataset_version):\n        if dataset_version == '2019':\n            cls.CLASSES = cls.CLASSES_2019_version\n        elif dataset_version == '2021':\n            cls.CLASSES = cls.CLASSES_2021_version\n        else:\n            raise NotImplementedError('Not supported YouTubeVIS dataset'\n                                      f'version: {dataset_version}')\n\n    def format_results(self,\n                       _results,\n                       resfile_path=None,\n                       metrics=['track_segm']):\n        \"\"\"Format the results to a zip file (standard format for YouTube-VIS\n        Challenge).\n        Args:\n            results (dict(list[ndarray])): Testing results of the dataset.\n            resfile_path (str, optional): Path to save the formatted results.\n                Defaults to None.\n            metrics (list[str], optional): The results of the specific metrics\n                will be formatted. Defaults to ['track_segm'].\n        Returns:\n            tuple: (resfiles, tmp_dir), resfiles is the path of the result\n            json file, tmp_dir is the temporal directory created for saving\n            files.\n        \"\"\"\n        results = {\n            'track_bboxes':[item[0] for item in _results],\n            'track_masks':[item[1] for item in _results]\n        }\n        data_infos = []\n        for item in self.data_infos:\n            data_infos.extend(item[1:])\n        assert isinstance(results, dict), 'results must be a dict.'\n        if isinstance(metrics, str):\n            metrics = [metrics]\n        assert 'track_segm' in metrics\n        if resfile_path is None:\n            tmp_dir = tempfile.TemporaryDirectory()\n            resfile_path = tmp_dir.name\n        else:\n            tmp_dir = None\n            if not os.path.exists(resfile_path):\n                os.makedirs(resfile_path)\n        resfiles = osp.join(resfile_path, 'results.json')\n\n        inds = [i for i, _ in 
enumerate(data_infos) if _['frame_id'] == 0]\n        num_vids = len(inds)\n        assert num_vids == len(self.vid_ids)\n        inds.append(len(data_infos))\n        vid_infos = self.coco.load_vids(self.vid_ids)\n\n        json_results = []\n        for i in range(num_vids):\n            video_id = vid_infos[i]['id']\n            # collect data for each instances in a video.\n            collect_data = dict()\n            for frame_id, (bbox_res, mask_res) in enumerate(\n                    zip(results['track_bboxes'][inds[i]:inds[i + 1]],\n                        results['track_masks'][inds[i]:inds[i + 1]])):\n                outs_track = results2outs(bbox_results=bbox_res)\n                bboxes = outs_track['bboxes']\n                labels = outs_track['labels']\n                ids = outs_track['ids']\n                masks = mmcv.concat_list(mask_res)\n                assert len(masks) == len(bboxes)\n                for j, id in enumerate(ids):\n                    if id not in collect_data:\n                        collect_data[id] = dict(\n                            category_ids=[], scores=[], segmentations=dict())\n                    collect_data[id]['category_ids'].append(labels[j])\n                    collect_data[id]['scores'].append(bboxes[j][4])\n                    if isinstance(masks[j]['counts'], bytes):\n                        masks[j]['counts'] = masks[j]['counts'].decode()\n                    collect_data[id]['segmentations'][frame_id] = masks[j]\n\n            # transform the collected data into official format\n            for id, id_data in collect_data.items():\n                output = dict()\n                output['video_id'] = video_id\n                output['score'] = np.array(id_data['scores']).mean().item()\n                # majority voting for sequence category\n                output['category_id'] = np.bincount(\n                    np.array(id_data['category_ids'])).argmax().item() + 1\n                
output['segmentations'] = []\n                for frame_id in range(inds[i + 1] - inds[i]):\n                    if frame_id in id_data['segmentations']:\n                        output['segmentations'].append(\n                            id_data['segmentations'][frame_id])\n                    else:\n                        output['segmentations'].append(None)\n                json_results.append(output)\n        mmcv.dump(json_results, resfiles)\n\n        # zip the json file in order to submit to the test server.\n        zip_file_name = osp.join(resfile_path, 'submission_file.zip')\n        zf = zipfile.ZipFile(zip_file_name, 'w', zipfile.ZIP_DEFLATED)\n        print_log(f\"zip the 'results.json' into '{zip_file_name}', \"\n                  'please submit the zip file to the test server')\n        zf.write(resfiles, 'results.json')\n        zf.close()\n\n        return resfiles, tmp_dir\n"
  },
  {
    "path": "mmtrack/pipelines/__init__.py",
    "content": "from .formatting import *\nfrom .loading import *\nfrom .test_time_aug import *\nfrom .transforms import *"
  },
  {
    "path": "mmtrack/pipelines/formatting.py",
    "content": "import numpy as np\nimport torch\nfrom mmcv.parallel import DataContainer as DC\nfrom mmdet.datasets.builder import PIPELINES\nfrom mmdet.datasets.pipelines import to_tensor\n\n\n@PIPELINES.register_module()\nclass ConcatVideoReferences(object):\n    \"\"\"Concat video references.\n\n    If the input list contains at least two dicts, concat the input list of\n    dict to one dict from 2-nd dict of the input list.\n\n    Args:\n        results (list[dict]): List of dict that contain keys such as 'img',\n            'img_metas', 'gt_masks','proposals', 'gt_bboxes',\n            'gt_bboxes_ignore', 'gt_labels','gt_semantic_seg',\n            'gt_instance_ids'.\n\n    Returns:\n        list[dict]: The first dict of outputs is the same as the first\n        dict of `results`. The second dict of outputs concats the\n        dicts in `results[1:]`.\n    \"\"\"\n\n    def __call__(self, results):\n        assert (isinstance(results, list)), 'results must be list'\n        outs = results[:1]\n        for i, result in enumerate(results[1:], 1):\n            if 'img' in result:\n                img = result['img']\n                if len(img.shape) < 3:\n                    img = np.expand_dims(img, -1)\n                if i == 1:\n                    result['img'] = np.expand_dims(img, -1)\n                else:\n                    outs[1]['img'] = np.concatenate(\n                        (outs[1]['img'], np.expand_dims(img, -1)), axis=-1)\n            for key in ['img_metas', 'gt_masks']:\n                if key in result:\n                    if i == 1:\n                        result[key] = [result[key]]\n                    else:\n                        outs[1][key].append(result[key])\n            for key in [\n                'proposals', 'gt_bboxes', 'gt_bboxes_ignore', 'gt_labels',\n                'gt_instance_ids',\n            ]:\n                if key not in result:\n                    continue\n                value = result[key]\n            
    if value.ndim == 1:\n                    value = value[:, None]\n                N = value.shape[0]\n                value = np.concatenate((np.full(\n                    (N, 1), i - 1, dtype=int if key in ['gt_labels', 'gt_instance_ids'] else np.float32\n                ), value), axis=1)\n                if i == 1:\n                    result[key] = value\n                else:\n                    outs[1][key] = np.concatenate((outs[1][key], value), axis=0)\n            if 'gt_semantic_seg' in result:\n                if i == 1:\n                    result['gt_semantic_seg'] = result['gt_semantic_seg'][..., None, None]\n                else:\n                    outs[1]['gt_semantic_seg'] = np.concatenate(\n                        (outs[1]['gt_semantic_seg'], result['gt_semantic_seg'][..., None, None]), axis=-1)\n            if i == 1:\n                outs.append(result)\n        return outs\n\n\n@PIPELINES.register_module()\nclass ConcatVideos(object):\n    \"\"\"Concat video references.\n\n    If the input list contains at least two dicts, concat the input list of\n    dict to one dict from 2-nd dict of the input list.\n\n    Args:\n        results (list[dict]): List of dict that contain keys such as 'img',\n            'img_metas', 'gt_masks','proposals', 'gt_bboxes',\n            'gt_bboxes_ignore', 'gt_labels','gt_semantic_seg',\n            'gt_instance_ids'.\n\n    Returns:\n        list[dict]: The first dict of outputs is the same as the first\n        dict of `results`. 
The second dict of outputs concats the\n        dicts in `results[1:]`.\n    \"\"\"\n\n    def __call__(self, results):\n        assert (isinstance(results, list)), 'results must be list'\n        outs = results[:1]\n        # outs = []\n        for i, result in enumerate(results[0:], 1):\n            if 'img' in result:\n                img = result['img']\n                if len(img.shape) < 3:\n                    img = np.expand_dims(img, -1)\n                if i == 1:\n                    result['img'] = np.expand_dims(img, -1)\n                else:\n                    outs[1]['img'] = np.concatenate(\n                        (outs[1]['img'], np.expand_dims(img, -1)), axis=-1)\n            for key in ['img_metas', 'gt_masks']:\n                if key in result:\n                    if i == 1:\n                        result[key] = [result[key]]\n                    else:\n                        outs[1][key].append(result[key])\n            for key in [\n                'proposals', 'gt_bboxes', 'gt_bboxes_ignore', 'gt_labels',\n                'gt_instance_ids'\n            ]:\n                if key not in result:\n                    continue\n                value = result[key]\n                if value.ndim == 1:\n                    value = value[:, None]\n                N = value.shape[0]\n                value = np.concatenate((np.full(\n                    (N, 1), i - 1, dtype=int if key in ['gt_labels', 'gt_instance_ids'] else np.float32\n                ), value), axis=1)\n                if i == 1:\n                    result[key] = value\n                else:\n                    outs[1][key] = np.concatenate((outs[1][key], value),\n                                                  axis=0)\n            if 'gt_semantic_seg' in result:\n                if i == 1:\n                    result['gt_semantic_seg'] = result['gt_semantic_seg'][...,\n                                                                          None,\n                       
                                                   None]\n                else:\n                    outs[1]['gt_semantic_seg'] = np.concatenate(\n                        (outs[1]['gt_semantic_seg'],\n                         result['gt_semantic_seg'][..., None, None]),\n                        axis=-1)\n            if i == 1:\n                outs.append(result)\n        res = []\n        res.append(outs[1])\n        return res\n\n\n@PIPELINES.register_module()\nclass MultiImagesToTensor(object):\n    \"\"\"Multi images to tensor.\n\n    1. Transpose and convert image/multi-images to Tensor.\n    2. Add prefix to every key in the second dict of the inputs. Then, add\n    these keys and corresponding values into the outputs.\n\n    Args:\n        ref_prefix (str): The prefix of key added to the second dict of inputs.\n            Defaults to 'ref'.\n    \"\"\"\n\n    def __init__(self, ref_prefix='ref'):\n        self.ref_prefix = ref_prefix\n\n    def __call__(self, results):\n        \"\"\"Multi images to tensor.\n\n        1. Transpose and convert image/multi-images to Tensor.\n        2. Add prefix to every key in the second dict of the inputs. 
Then, add\n        these keys and corresponding values into the output dict.\n\n        Args:\n            results (list[dict]): List of two dicts.\n\n        Returns:\n            dict: Each key in the first dict of `results` remains unchanged.\n            Each key in the second dict of `results` adds `self.ref_prefix`\n            as prefix.\n        \"\"\"\n        outs = []\n        for _results in results:\n            _results = self.images_to_tensor(_results)\n            outs.append(_results)\n\n        data = {}\n        data.update(outs[0])\n        if len(outs) == 2:\n            for k, v in outs[1].items():\n                data[f'{self.ref_prefix}_{k}'] = v\n\n        return data\n\n    def images_to_tensor(self, results):\n        \"\"\"Transpose and convert images/multi-images to Tensor.\"\"\"\n        if 'img' in results:\n            img = results['img']\n            if len(img.shape) == 3:\n                # (H, W, 3) to (3, H, W)\n                img = np.ascontiguousarray(img.transpose(2, 0, 1))\n            else:\n                # (H, W, 3, N) to (N, 3, H, W)\n                img = np.ascontiguousarray(img.transpose(3, 2, 0, 1))\n            results['img'] = to_tensor(img)\n        if 'proposals' in results:\n            results['proposals'] = to_tensor(results['proposals'])\n        if 'img_metas' in results:\n            results['img_metas'] = DC(results['img_metas'], cpu_only=True)\n        return results\n\n\n@PIPELINES.register_module()\nclass SeqDefaultFormatBundle(object):\n    \"\"\"Sequence Default formatting bundle.\n\n    It simplifies the pipeline of formatting common fields, including \"img\",\n    \"img_metas\", \"proposals\", \"gt_bboxes\", \"gt_instance_ids\",\n    \"gt_match_indices\", \"gt_bboxes_ignore\", \"gt_labels\", \"gt_masks\" and\n    \"gt_semantic_seg\". 
These fields are formatted as follows.\n\n    - img: (1) transpose, (2) to tensor, (3) to DataContainer (stack=True)\n    - img_metas: (1) to DataContainer (cpu_only=True)\n    - proposals: (1) to tensor, (2) to DataContainer\n    - gt_bboxes: (1) to tensor, (2) to DataContainer\n    - gt_instance_ids: (1) to tensor, (2) to DataContainer\n    - gt_match_indices: (1) to tensor, (2) to DataContainer\n    - gt_bboxes_ignore: (1) to tensor, (2) to DataContainer\n    - gt_labels: (1) to tensor, (2) to DataContainer\n    - gt_masks: (1) to DataContainer (cpu_only=True)\n    - gt_semantic_seg: (1) unsqueeze dim-0 (2) to tensor, \\\n                       (3) to DataContainer (stack=True)\n\n    Args:\n        ref_prefix (str): The prefix of key added to the second dict of input\n            list. Defaults to 'ref'.\n    \"\"\"\n\n    def __init__(self, ref_prefix='ref'):\n        self.ref_prefix = ref_prefix\n\n    def __call__(self, results):\n        \"\"\"Sequence Default formatting bundle call function.\n\n        Args:\n            results (list[dict]): List of two dicts.\n\n        Returns:\n            dict: The result dict contains the data that is formatted with\n            default bundle. 
Each key in the second dict of the input list\n            adds `self.ref_prefix` as prefix.\n        \"\"\"\n        outs = []\n        for _results in results:\n            _results = self.default_format_bundle(_results)\n            outs.append(_results)\n\n        data = {}\n        if self.ref_prefix == 'ref':\n            # origin frames\n            data.update(outs[0])\n            # reference frames\n            if len(outs) == 1:\n                # for k in outs[0]:\n                #     data[f'{self.ref_prefix}_{k}'] = None\n                pass\n            else:\n                for k, v in outs[1].items():\n                    data[f'{self.ref_prefix}_{k}'] = v\n        elif self.ref_prefix is None:\n            # origin frames\n            data.update(outs[0])\n\n        return data\n\n    def default_format_bundle(self, results):\n        \"\"\"Transform and format common fields in results.\n\n        Args:\n            results (dict): Result dict contains the data to convert.\n\n        Returns:\n            dict: The result dict contains the data that is formatted with\n            default bundle.\n        \"\"\"\n        if 'img' in results:\n            img = results['img']\n            if len(img.shape) == 3:\n                img = np.ascontiguousarray(img.transpose(2, 0, 1))\n            else:\n                img = np.ascontiguousarray(img.transpose(3, 2, 0, 1))\n            results['img'] = DC(to_tensor(img), stack=True)\n        for key in [\n            'proposals', 'gt_bboxes', 'gt_bboxes_ignore', 'gt_labels',\n            'gt_instance_ids', 'gt_match_indices',\n        ]:\n            if key not in results:\n                continue\n            results[key] = DC(to_tensor(results[key]))\n        for key in ['img_metas', 'gt_masks']:\n            if key in results:\n                results[key] = DC(results[key], cpu_only=True)\n        if 'gt_semantic_seg' in results:\n            semantic_seg = results['gt_semantic_seg']\n            
if len(semantic_seg.shape) == 2:\n                semantic_seg = semantic_seg[None, ...]\n            else:\n                semantic_seg = np.ascontiguousarray(\n                    semantic_seg.transpose(3, 2, 0, 1))\n            results['gt_semantic_seg'] = DC(\n                to_tensor(semantic_seg), stack=True)\n        return results\n\n    def __repr__(self):\n        return self.__class__.__name__\n\n\n@PIPELINES.register_module()\nclass VideoCollect(object):\n    \"\"\"Collect data from the loader relevant to the specific task.\n\n    Args:\n        keys (Sequence[str]): Keys of results to be collected in ``data``.\n        meta_keys (Sequence[str]): Meta keys to be converted to\n            ``mmcv.DataContainer`` and collected in ``data[img_metas]``.\n            Defaults to None.\n        default_meta_keys (tuple): Default meta keys. Defaults to ('filename',\n            'ori_filename', 'ori_shape', 'img_shape', 'pad_shape',\n            'scale_factor', 'flip', 'flip_direction', 'img_norm_cfg',\n            'frame_id', 'is_video_data').\n    \"\"\"\n\n    def __init__(self,\n                 keys,\n                 meta_keys=None,\n                 reject_empty=False,\n                 num_ref_imgs=0,\n                 # no_obj_class is added for handling non-0  no-obj class\n                 default_meta_keys=('filename', 'ori_filename', 'ori_shape',\n                                    'img_shape', 'pad_shape', 'scale_factor',\n                                    'flip', 'flip_direction', 'img_norm_cfg',\n                                    'video_id',\n                                    'frame_id', 'is_video_data', 'no_obj_class')):\n        self.keys = keys\n        self.meta_keys = default_meta_keys\n        if meta_keys is not None:\n            if isinstance(meta_keys, str):\n                meta_keys = (meta_keys,)\n            else:\n                assert isinstance(meta_keys, tuple), \\\n                    'meta_keys must be str or tuple'\n 
           self.meta_keys += meta_keys\n\n        self.reject_empty = reject_empty\n        self.num_ref_imgs = num_ref_imgs\n\n    def __call__(self, results):\n        \"\"\"Call function to collect keys in results.\n\n        The keys in ``meta_keys`` and ``default_meta_keys`` will be converted\n        to :obj:mmcv.DataContainer.\n\n        Args:\n            results (list[dict] | dict): List of dict or dict which contains\n                the data to collect.\n\n        Returns:\n            list[dict] | dict: List of dict or dict that contains the\n            following keys:\n\n            - keys in ``self.keys``\n            - ``img_metas``\n        \"\"\"\n        results_is_dict = isinstance(results, dict)\n        if results_is_dict:\n            results = [results]\n        outs = []\n        for _results in results:\n            _results = self._add_default_meta_keys(_results)\n            _results = self._collect_meta_keys(_results)\n            outs.append(_results)\n\n        if results_is_dict:\n            outs[0]['img_metas'] = DC(outs[0]['img_metas'], cpu_only=True)\n\n        if self.reject_empty:\n            if len(results[0]['gt_labels']) == 0:\n                return None\n        if self.num_ref_imgs > 0:\n            if len(results) != self.num_ref_imgs + 1:\n                raise NotImplementedError\n        return outs[0] if results_is_dict else outs\n\n    def _collect_meta_keys(self, results):\n        \"\"\"Collect `self.keys` and `self.meta_keys` from `results` (dict).\"\"\"\n        data = {}\n        img_meta = {}\n        for key in self.meta_keys:\n            if key in results:\n                img_meta[key] = results[key]\n            elif key in results['img_info']:\n                img_meta[key] = results['img_info'][key]\n        data['img_metas'] = img_meta\n        for key in self.keys:\n            data[key] = results[key]\n        return data\n\n    def _add_default_meta_keys(self, results):\n        \"\"\"Add default 
meta keys.\n\n        We set default meta keys including `pad_shape`, `scale_factor` and\n        `img_norm_cfg` to avoid the case where no `Resize`, `Normalize` and\n        `Pad` are applied during the whole pipeline.\n\n        Args:\n            results (dict): Result dict containing the data to convert.\n\n        Returns:\n            results (dict): Updated result dict containing the data to convert.\n        \"\"\"\n        img = results['img']\n        results.setdefault('pad_shape', img.shape)\n        results.setdefault('scale_factor', 1.0)\n        num_channels = 1 if len(img.shape) < 3 else img.shape[2]\n        results.setdefault(\n            'img_norm_cfg',\n            dict(\n                mean=np.zeros(num_channels, dtype=np.float32),\n                std=np.ones(num_channels, dtype=np.float32),\n                to_rgb=False))\n        return results\n\n\n@PIPELINES.register_module()\nclass ToList(object):\n    \"\"\"Wrap each value of the input dict in a list.\n\n    Args:\n        results (dict): Result dict containing the data to convert.\n\n    Returns:\n        dict: Updated result dict containing the data to convert.\n    \"\"\"\n\n    def __call__(self, results):\n        out = {}\n        for k, v in results.items():\n            out[k] = [v]\n        return out\n\n\n@PIPELINES.register_module()\nclass ReIDFormatBundle(object):\n    \"\"\"ReID formatting bundle.\n\n    It first concatenates common fields, then simplifies the pipeline of\n    formatting common fields, including \"img\" and \"gt_label\".\n    These fields are formatted as follows.\n\n    - img: (1) transpose, (2) to tensor, (3) to DataContainer (stack=True)\n    - gt_label: (1) to tensor, (2) to DataContainer (stack=True)\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__()\n\n    def __call__(self, results):\n        \"\"\"ReID formatting bundle call function.\n\n        Args:\n            results (list[dict] or dict): List of dicts or dict.\n\n        
Returns:\n            dict: The result dict contains the data that is formatted with\n            ReID bundle.\n        \"\"\"\n        inputs = dict()\n        if isinstance(results, list):\n            assert len(results) > 1, \\\n                'the \\'results\\' only have one item, ' \\\n                'please directly use normal pipeline not \\'Seq\\' pipeline.'\n            inputs['img'] = np.stack([_results['img'] for _results in results],\n                                     axis=3)\n            inputs['gt_label'] = np.stack(\n                [_results['gt_label'] for _results in results], axis=0)\n        elif isinstance(results, dict):\n            inputs['img'] = results['img']\n            inputs['gt_label'] = results['gt_label']\n        else:\n            raise TypeError('results must be a list or a dict.')\n        outs = self.reid_format_bundle(inputs)\n\n        return outs\n\n    def reid_format_bundle(self, results):\n        \"\"\"Transform and format gt_label fields in results.\n\n        Args:\n            results (dict): Result dict contains the data to convert.\n\n        Returns:\n            dict: The result dict contains the data that is formatted with\n            ReID bundle.\n        \"\"\"\n        for key in results:\n            if key == 'img':\n                img = results[key]\n                if img.ndim == 3:\n                    img = np.ascontiguousarray(img.transpose(2, 0, 1))\n                else:\n                    img = np.ascontiguousarray(img.transpose(3, 2, 0, 1))\n                results['img'] = DC(to_tensor(img), stack=True)\n            elif key == 'gt_label':\n                results[key] = DC(\n                    to_tensor(results[key]), stack=True, pad_dims=None)\n            else:\n                raise KeyError(f'key {key} is not supported')\n        return results\n\n\n@PIPELINES.register_module()\nclass ImageToTensorWithRef(object):\n\n    def __init__(self, keys):\n        self.keys = keys\n\n    
def __call__(self, results):\n\n        for key in self.keys:\n            if key in ['ref_img']:\n                if isinstance(results[key], list):\n                    img_ref = []\n                    for img in results[key]:\n                        img = np.ascontiguousarray(img.transpose(2, 0, 1))\n                        img_ref.append(img)\n                    img_ref = np.array(img_ref)\n                    results[key] = to_tensor(img_ref)\n                else:\n                    img = np.ascontiguousarray(results[key].transpose(2, 0, 1))\n                    results[key] = to_tensor(img)\n            else:\n                results[key] = to_tensor(results[key].transpose(2, 0, 1))\n        return results\n\n    def __repr__(self):\n        return self.__class__.__name__ + '(keys={})'.format(self.keys)\n\n\n@PIPELINES.register_module()\nclass LabelConsistentChecker:\n    \"\"\"Check that instance annotations are consistent across a video clip.\n\n    The formatted `ref_gt_instance_ids` (rows of [frame index, instance id])\n    must list the same instances in the same order for each of the\n    `num_frames` frames. Returns None if the check fails.\n    \"\"\"\n    def __init__(self, num_frames=5):\n        self.num_frames = num_frames\n\n    def __call__(self, results):\n        ref_gt_instance_ids = results['ref_gt_instance_ids'].data\n        ins_mul_nframe = ref_gt_instance_ids.size(0)\n        if ins_mul_nframe % self.num_frames != 0:\n            return None\n        num_ins = ins_mul_nframe // self.num_frames\n        ins_id_bucket = torch.zeros((num_ins,), dtype=torch.float)\n        for i in range(ins_mul_nframe):\n            frame_cur = i // num_ins\n            ins_cur = i % num_ins\n            if ref_gt_instance_ids[i][0] != frame_cur:\n                return None\n            if frame_cur == 0:\n                ins_id_bucket[ins_cur] = ref_gt_instance_ids[i][1]\n            else:\n                if ref_gt_instance_ids[i][1] != ins_id_bucket[ins_cur]:\n                    return None\n        return results\n\n\n@PIPELINES.register_module()\nclass MM2CLIP:\n    \"\"\"Check that instance annotations are consistent across a video clip.\n\n    Same check as `LabelConsistentChecker`, applied to the unformatted\n    `gt_instance_ids` of the concatenated clip in `results[1]`. Returns None\n    if the check fails.\n    
\"\"\"\n    def __init__(self, num_frames=5):\n        self.num_frames = num_frames\n\n    def __call__(self, results):\n        ins_ids = np.unique(results[1]['gt_instance_ids'][:,1])\n        num_ins = len(ins_ids)\n        num_frames = len(results[1]['img_metas'])\n        ins_id_bucket = np.zeros((num_ins,), dtype=float)\n        for i in range(num_ins * num_frames):\n            frame_cur = i // num_ins\n            ins_cur = i % num_ins\n            if results[1]['gt_instance_ids'][i][0] != frame_cur:\n                return None\n            if frame_cur == 0:\n                ins_id_bucket[ins_cur] = results[1]['gt_instance_ids'][i][1]\n            else:\n                if results[1]['gt_instance_ids'][i][1] != ins_id_bucket[ins_cur]:\n                    return None\n        return results\n\n"
  },
  {
    "path": "mmtrack/pipelines/loading.py",
    "content": "import os.path as osp\nimport numpy as np\n\nimport mmcv\nfrom mmdet.core import BitmapMasks\n\nfrom mmdet.datasets.builder import PIPELINES\nfrom mmdet.datasets.pipelines import LoadAnnotations, LoadImageFromFile\n\n\n@PIPELINES.register_module()\nclass LoadMultiImagesFromFile(LoadImageFromFile):\n    \"\"\"Load multi images from file.\n    Please refer to `mmdet.datasets.pipelines.loading.py:LoadImageFromFile`\n    for detailed docstring.\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n\n    def __call__(self, results):\n        \"\"\"Call function.\n        For each dict in `results`, call the call function of\n        `LoadImageFromFile` to load image.\n        Args:\n            results (list[dict]): List of dict from\n                :obj:`mmtrack.CocoVideoDataset`.\n        Returns:\n            list[dict]: List of dict that contains loaded image.\n        \"\"\"\n        outs = []\n        for _results in results:\n            _results = super().__call__(_results)\n            outs.append(_results)\n        return outs\n\n\n@PIPELINES.register_module()\nclass SeqLoadAnnotations(LoadAnnotations):\n    \"\"\"Sequence load annotations.\n    Please refer to `mmdet.datasets.pipelines.loading.py:LoadAnnotations`\n    for detailed docstring.\n    Args:\n        with_track (bool): If True, load instance ids of bboxes.\n    \"\"\"\n\n    def __init__(self, with_track=False, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.with_track = with_track\n\n    def _load_track(self, results):\n        \"\"\"Private function to load label annotations.\n        Args:\n            results (dict): Result dict from :obj:`mmtrack.CocoVideoDataset`.\n        Returns:\n            dict: The dict contains loaded label annotations.\n        \"\"\"\n\n        results['gt_instance_ids'] = results['ann_info']['instance_ids'].copy()\n\n        return results\n\n    def __call__(self, results):\n  
      \"\"\"Call function.\n        For each dict in results, call the call function of `LoadAnnotations`\n        to load annotation.\n        Args:\n            results (list[dict]): List of dict that from\n                :obj:`mmtrack.CocoVideoDataset`.\n        Returns:\n            list[dict]: List of dict that contains loaded annotations, such as\n            bounding boxes, labels, instance ids, masks and semantic\n            segmentation annotations.\n        \"\"\"\n        outs = []\n        for _results in results:\n            _results = super().__call__(_results)\n            if self.with_track:\n                _results = self._load_track(_results)\n            outs.append(_results)\n        return outs\n\n\n@PIPELINES.register_module()\nclass LoadRefImageFromFile(object):\n    \"\"\"\n    Code reading reference frame information.\n    Specific to Cityscapes-VPS, Cityscapes, and VIPER datasets.\n    \"\"\"\n\n    def __init__(self, sample=True, to_float32=False):\n        self.to_float32 = to_float32\n        self.sample = sample\n\n    def __call__(self, results):\n        # requires dirname for ref images\n        assert results['ref_prefix'] is not None, 'ref_prefix must be specified.'\n\n        filename = osp.join(results['img_prefix'],\n                            results['img_info']['filename'])\n        img = mmcv.imread(filename)\n        # if specified by another ref json file.\n        if 'ref_filename' in results['img_info']:\n            ref_filename = osp.join(results['ref_prefix'],\n                                    results['img_info']['ref_filename'])\n            ref_img = mmcv.imread(ref_filename)  # [1024, 2048, 3]\n        else:\n            raise NotImplementedError('We need this implementation.')\n\n        if self.to_float32:\n            img = img.astype(np.float32)\n            ref_img = ref_img.astype(np.float32)\n\n        results['filename'] = filename\n        results['ori_filename'] = results['img_info']['filename']\n 
       results['img'] = img\n        results['img_shape'] = img.shape\n        results['ori_shape'] = img.shape\n        results['ref_img'] = ref_img\n        results['iid'] = results['img_info']['id']\n        return results\n\n    def __repr__(self):\n        return self.__class__.__name__ + '(to_float32={})'.format(\n            self.to_float32)\n\n\ndef bitmasks2bboxes(bitmasks):\n    bitmasks_array = bitmasks.masks\n    boxes = np.zeros((bitmasks_array.shape[0], 4), dtype=np.float32)\n    x_any = np.any(bitmasks_array, axis=1)\n    y_any = np.any(bitmasks_array, axis=2)\n    for idx in range(bitmasks_array.shape[0]):\n        x = np.where(x_any[idx, :])[0]\n        y = np.where(y_any[idx, :])[0]\n        if len(x) > 0 and len(y) > 0:\n            boxes[idx, :] = np.array((x[0], y[0], x[-1], y[-1]), dtype=np.float32)\n    return boxes\n\n\n@PIPELINES.register_module()\nclass LoadAnnotationsInstanceMasks:\n    def __init__(self,\n                 with_mask=True,\n                 with_seg=True,\n                 with_inst=False,\n                 cherry=None,\n                 file_client_args=dict(backend='disk')):\n        self.with_mask = with_mask\n        self.with_seg = with_seg\n        self.with_inst = with_inst\n        self.file_client_args = file_client_args.copy()\n        self.cherry = cherry\n        self.file_client = None\n\n    def _load_masks(self, results):\n        \"\"\"Private function to load mask annotations.\n        Args:\n            results (dict): Result dict from :obj:`mmdet.CustomDataset`.\n        Returns:\n            dict: The dict contains loaded mask annotations.\n                If ``self.poly2mask`` is set ``True``, `gt_mask` will contain\n                :obj:`PolygonMasks`. 
Otherwise, :obj:`BitmapMasks` is used.\n        \"\"\"\n\n        img_bytes = self.file_client.get(results['ann_info']['inst_map'])\n        inst_mask = mmcv.imfrombytes(img_bytes, flag='unchanged').squeeze()\n        if self.with_inst:\n            results['gt_instance_map'] = inst_mask.copy().astype(int)\n            results['gt_instance_map'][inst_mask < 10000] *= 1000\n        if not self.with_mask:\n            return results\n        masks = []\n        labels = []\n        for inst_id in np.unique(inst_mask):\n            if inst_id >= 10000:\n                if self.cherry is not None and not (inst_id // 1000 in self.cherry):\n                    continue\n                masks.append((inst_mask == inst_id).astype(int))\n                labels.append(inst_id // 1000)\n        if len(masks) == 0:\n            return None\n        gt_masks = BitmapMasks(masks, height=inst_mask.shape[0], width=inst_mask.shape[1])\n        results['gt_masks'] = gt_masks\n        results['mask_fields'].append('gt_masks')\n        results['gt_labels'] = np.array(labels)\n\n        boxes = bitmasks2bboxes(gt_masks)\n        results['gt_bboxes'] = boxes\n        results['bbox_fields'].append('gt_bboxes')\n        return results\n\n    def _load_semantic_seg(self, results):\n        \"\"\"Private function to load semantic segmentation annotations.\n        Args:\n            results (dict): Result dict from :obj:`dataset`.\n        Returns:\n            dict: The dict contains loaded semantic segmentation annotations.\n        \"\"\"\n        img_bytes = self.file_client.get(results['ann_info']['seg_map'])\n        results['gt_semantic_seg'] = mmcv.imfrombytes(\n            img_bytes, flag='unchanged').squeeze()\n        results['seg_fields'].append('gt_semantic_seg')\n        return results\n\n    def __call__(self, results):\n        \"\"\"Call function to load multiple types annotations.\n        Args:\n            results (dict): Result dict from :obj:`mmdet.CustomDataset`.\n    
    Returns:\n            dict | None: The dict contains loaded bounding box, label, mask\n                and semantic segmentation annotations, or None if no valid\n                instance mask is found.\n        \"\"\"\n        if self.file_client is None:\n            self.file_client = mmcv.FileClient(**self.file_client_args)\n        if self.with_mask or self.with_inst:\n            results = self._load_masks(results)\n            if results is None:\n                return None\n        if self.with_seg:\n            results = self._load_semantic_seg(results)\n        return results\n\n    def __repr__(self):\n        repr_str = self.__class__.__name__\n        repr_str += f'(with_mask={self.with_mask}, '\n        repr_str += f'with_seg={self.with_seg}, '\n        repr_str += f'with_inst={self.with_inst}, '\n        repr_str += f'cherry={self.cherry})'\n        return repr_str\n"
  },
  {
    "path": "mmtrack/pipelines/test_time_aug.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport warnings\n\nimport mmcv\n\nfrom mmdet.datasets.builder import PIPELINES\nfrom mmdet.datasets.pipelines import Compose\n\n\n@PIPELINES.register_module()\nclass MultiScaleFlipAugVideo:\n    \"\"\"Test-time augmentation with multiple scales and flipping.\n    An example configuration is as follows:\n    .. code-block::\n        img_scale=[(1333, 400), (1333, 800)],\n        flip=True,\n        transforms=[\n            dict(type='Resize', keep_ratio=True),\n            dict(type='RandomFlip'),\n            dict(type='Normalize', **img_norm_cfg),\n            dict(type='Pad', size_divisor=32),\n            dict(type='ImageToTensor', keys=['img']),\n            dict(type='Collect', keys=['img']),\n        ]\n    After MultiScaleFlipAugVideo with the above configuration, the results are\n    wrapped into lists of the same length as follows:\n    .. code-block::\n        dict(\n            img=[...],\n            img_shape=[...],\n            scale=[(1333, 400), (1333, 400), (1333, 800), (1333, 800)],\n            flip=[False, True, False, True],\n            ...\n        )\n    Args:\n        transforms (list[dict]): Transforms to apply in each augmentation.\n        img_scale (tuple | list[tuple] | None): Image scales for resizing.\n        scale_factor (float | list[float] | None): Scale factors for resizing.\n            Exactly one of ``img_scale`` and ``scale_factor`` must be set.\n        flip (bool): Whether to apply flip augmentation. Default: False.\n        flip_direction (str | list[str]): Flip augmentation directions,\n            options are \"horizontal\", \"vertical\" and \"diagonal\". If\n            flip_direction is a list, multiple flip augmentations will be\n            applied. It has no effect when flip == False. 
Default:\n            \"horizontal\".\n    \"\"\"\n\n    def __init__(self,\n                 transforms,\n                 img_scale=None,\n                 scale_factor=None,\n                 flip=False,\n                 flip_direction='horizontal'):\n        self.transforms = Compose(transforms)\n        assert (img_scale is None) ^ (scale_factor is None), (\n            'Exactly one of img_scale and scale_factor must be set')\n        if img_scale is not None:\n            self.img_scale = img_scale if isinstance(img_scale,\n                                                     list) else [img_scale]\n            self.scale_key = 'scale'\n            assert mmcv.is_list_of(self.img_scale, tuple)\n        else:\n            self.img_scale = scale_factor if isinstance(\n                scale_factor, list) else [scale_factor]\n            self.scale_key = 'scale_factor'\n\n        self.flip = flip\n        self.flip_direction = flip_direction if isinstance(\n            flip_direction, list) else [flip_direction]\n        assert mmcv.is_list_of(self.flip_direction, str)\n        if not self.flip and self.flip_direction != ['horizontal']:\n            warnings.warn(\n                'flip_direction has no effect when flip is set to False')\n        if (self.flip\n                and not any([t['type'] in ('RandomFlip', 'SeqRandomFlip')\n                             for t in transforms])):\n            warnings.warn(\n                'flip has no effect when RandomFlip or SeqRandomFlip is not '\n                'in transforms')\n\n    def __call__(self, results):\n        \"\"\"Call function to apply test time augment transforms on results.\n        Args:\n            results (list[dict]): List of per-frame result dicts to transform.\n        Returns:\n            dict[str, list]: The augmented data, where each value is wrapped\n                into a list with one entry per (scale, flip) combination.\n        \"\"\"\n\n        aug_data = []\n        flip_args = [(False, None)]\n        if self.flip:\n            flip_args += [(True, direction)\n                          for direction in self.flip_direction]\n    
    for scale in self.img_scale:\n            for flip, direction in flip_args:\n                _results = []\n                for results_single in results:\n                    _results_single = results_single.copy()\n                    _results_single[self.scale_key] = scale\n                    _results_single['flip'] = flip\n                    _results_single['flip_direction'] = direction\n                    _results.append(_results_single)\n                data = self.transforms(_results)\n                aug_data.append(data)\n        # list of dict to dict of list\n        aug_data_dict = {key: [] for key in aug_data[0]}\n        for data in aug_data:\n            for key, val in data.items():\n                aug_data_dict[key].append(val)\n        return aug_data_dict\n\n    def __repr__(self):\n        repr_str = self.__class__.__name__\n        repr_str += f'(transforms={self.transforms}, '\n        repr_str += f'img_scale={self.img_scale}, flip={self.flip}, '\n        repr_str += f'flip_direction={self.flip_direction})'\n        return repr_str"
  },
  {
    "path": "mmtrack/pipelines/transforms.py",
    "content": "import cv2\nimport mmcv\nimport numpy as np\nimport warnings\nfrom mmdet.datasets.builder import PIPELINES\nfrom mmdet.datasets.pipelines import Normalize, Pad, RandomFlip, Resize\n\n\n@PIPELINES.register_module()\nclass SeqColorAug(object):\n    \"\"\"Color augmentation for images.\n    Args:\n        prob (list[float]): The probability to perform color augmentation for\n            each image. Defaults to [1.0, 1.0].\n        rgb_var (list[list]): The values of color augmentation. Defaults to\n            [[-0.55919361, 0.98062831, -0.41940627],\n            [1.72091413, 0.19879334, -1.82968581],\n            [4.64467907, 4.73710203, 4.88324118]].\n    \"\"\"\n\n    def __init__(self,\n                 prob=[1.0, 1.0],\n                 rgb_var=[[-0.55919361, 0.98062831, -0.41940627],\n                          [1.72091413, 0.19879334, -1.82968581],\n                          [4.64467907, 4.73710203, 4.88324118]]):\n        self.prob = prob\n        self.rgb_var = np.array(rgb_var, dtype=np.float32)\n\n    def __call__(self, results):\n        \"\"\"Call function.\n        For each dict in results, perform color augmentation on the image in\n        the dict.\n        Args:\n            results (list[dict]): List of dicts from\n                :obj:`mmtrack.CocoVideoDataset`.\n        Returns:\n            list[dict]: List of dicts that contain the color-augmented images.\n        \"\"\"\n        outs = []\n        for i, _results in enumerate(results):\n            image = _results['img']\n\n            if self.prob[i] > np.random.random():\n                offset = np.dot(self.rgb_var, np.random.randn(3, 1))\n                # bgr to rgb\n                offset = offset[::-1]\n                offset = offset.reshape(3)\n                image = (image - offset).astype(np.float32)\n\n            _results['img'] = image\n            outs.append(_results)\n        return outs\n\n\n@PIPELINES.register_module()\nclass SeqBlurAug(object):\n    \"\"\"Blur 
augmentation for images.\n    Args:\n        prob (list[float]): The probability to perform blur augmentation for\n            each image. Defaults to [0.0, 0.2].\n    \"\"\"\n\n    def __init__(self, prob=[0.0, 0.2]):\n        self.prob = prob\n\n    def __call__(self, results):\n        \"\"\"Call function.\n        For each dict in results, perform blur augmentation on the image in\n        the dict.\n        Args:\n            results (list[dict]): List of dicts from\n                :obj:`mmtrack.CocoVideoDataset`.\n        Returns:\n            list[dict]: List of dicts that contain the blurred images.\n        \"\"\"\n        outs = []\n        for i, _results in enumerate(results):\n            image = _results['img']\n\n            if self.prob[i] > np.random.random():\n                sizes = np.arange(5, 46, 2)\n                size = np.random.choice(sizes)\n                kernel = np.zeros((size, size))\n                c = int(size / 2)\n                wx = np.random.random()\n                kernel[:, c] += 1. / size * wx\n                kernel[c, :] += 1. / size * (1 - wx)\n                image = cv2.filter2D(image, -1, kernel)\n\n            _results['img'] = image\n            outs.append(_results)\n        return outs\n\n\n@PIPELINES.register_module()\nclass SeqResize(Resize):\n    \"\"\"Resize images.\n    Please refer to `mmdet.datasets.pipelines.transforms.py:Resize` for\n    detailed docstring.\n    Args:\n        share_params (bool): If True, share the resize parameters for all\n            images. 
Defaults to True.\n    \"\"\"\n\n    def __init__(self, share_params=True, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.share_params = share_params\n\n    def __call__(self, results):\n        \"\"\"Call function.\n        For each dict in results, call `Resize.__call__` to resize the image\n        and the corresponding annotations.\n        Args:\n            results (list[dict]): List of dicts from\n                :obj:`mmtrack.CocoVideoDataset`.\n        Returns:\n            list[dict]: List of dicts that contain the resized results. The\n            'img_shape', 'pad_shape', 'scale_factor' and 'keep_ratio' keys\n            are added into each result dict.\n        \"\"\"\n        outs, scale = [], None\n        for i, _results in enumerate(results):\n            if self.share_params and i > 0:\n                _results['scale'] = scale\n            _results = super().__call__(_results)\n            if self.share_params and i == 0:\n                scale = _results['scale']\n            outs.append(_results)\n        return outs\n\n\n@PIPELINES.register_module()\nclass SeqNormalize(Normalize):\n    \"\"\"Normalize images.\n    Please refer to `mmdet.datasets.pipelines.transforms.py:Normalize` for\n    detailed docstring.\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n\n    def __call__(self, results):\n        \"\"\"Call function.\n        For each dict in results, call `Normalize.__call__` to normalize the\n        image.\n        Args:\n            results (list[dict]): List of dicts from\n                :obj:`mmtrack.CocoVideoDataset`.\n        Returns:\n            list[dict]: List of dicts that contain the normalized results. The\n            'img_norm_cfg' key is added into each result dict.\n        \"\"\"\n        outs = []\n        for _results in results:\n            _results = super().__call__(_results)\n            outs.append(_results)\n        return 
outs\n\n\n@PIPELINES.register_module()\nclass SeqRandomFlip(RandomFlip):\n    \"\"\"Randomly flip images.\n    Please refer to `mmdet.datasets.pipelines.transforms.py:RandomFlip` for\n    detailed docstring.\n    Args:\n        share_params (bool): If True, share the flip parameters for all images.\n            Defaults to True.\n    \"\"\"\n\n    def __init__(self, share_params=True, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.share_params = share_params\n\n    def __call__(self, results):\n        \"\"\"Call function.\n        For each dict in results, call `RandomFlip` to randomly flip the image.\n        Args:\n            results (list[dict]): List of dicts from\n                :obj:`mmtrack.CocoVideoDataset`.\n        Returns:\n            list[dict]: List of dicts that contain the flipped results. The\n            'flip' and 'flip_direction' keys are added into each dict.\n        \"\"\"\n        if self.share_params:\n            if isinstance(self.direction, list):\n                # None means non-flip\n                direction_list = self.direction + [None]\n            else:\n                # None means non-flip\n                direction_list = [self.direction, None]\n\n            if isinstance(self.flip_ratio, list):\n                non_flip_ratio = 1 - sum(self.flip_ratio)\n                flip_ratio_list = self.flip_ratio + [non_flip_ratio]\n            else:\n                non_flip_ratio = 1 - self.flip_ratio\n                # exclude non-flip\n                single_ratio = self.flip_ratio / (len(direction_list) - 1)\n                flip_ratio_list = [single_ratio] * (len(direction_list) -\n                                                    1) + [non_flip_ratio]\n\n            cur_dir = np.random.choice(direction_list, p=flip_ratio_list)\n            flip = cur_dir is not None\n            flip_direction = cur_dir\n\n            for _results in results:\n                _results['flip'] = flip\n                
_results['flip_direction'] = flip_direction\n\n        outs = []\n        for _results in results:\n            _results = super().__call__(_results)\n            outs.append(_results)\n        return outs\n\n\n@PIPELINES.register_module()\nclass SeqPad(Pad):\n    \"\"\"Pad images.\n    Please refer to `mmdet.datasets.pipelines.transforms.py:Pad` for detailed\n    docstring.\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n\n    def __call__(self, results):\n        \"\"\"Call function.\n        For each dict in results, call `Pad.__call__` to pad the image.\n        Args:\n            results (list[dict]): List of dicts from\n                :obj:`mmtrack.CocoVideoDataset`.\n        Returns:\n            list[dict]: List of dicts that contain the padded results. The\n            'pad_shape', 'pad_fixed_size' and 'pad_size_divisor' keys are\n            added into each dict.\n        \"\"\"\n        outs = []\n        for _results in results:\n            _results = super().__call__(_results)\n            outs.append(_results)\n        return outs\n\n\n@PIPELINES.register_module()\nclass SeqRandomCrop(object):\n    \"\"\"Randomly crop the images, bboxes, masks and semantic segmentation maps\n    of a sequence.\n    The crop has the absolute size `crop_size`; its offsets are sampled\n    uniformly inside the image, then the cropped results are generated.\n    Args:\n        crop_size (tuple): The absolute crop size in pixels, as\n            (height, width).\n        allow_negative_crop (bool, optional): Whether to allow a crop that does\n            not contain any bbox area. Default False.\n        share_params (bool, optional): Whether to share the cropping\n            parameters across all images. Default False.\n        bbox_clip_border (bool, optional): Whether to clip the objects outside\n            the border of the image. 
Defaults to True.\n    Note:\n        - If the image is smaller than the absolute crop size, return the\n            original image.\n        - The keys for bboxes, labels and masks must be aligned. That is,\n          `gt_bboxes` corresponds to `gt_labels` and `gt_masks`, and\n          `gt_bboxes_ignore` corresponds to `gt_labels_ignore` and\n          `gt_masks_ignore`.\n        - If the crop does not contain any gt-bbox region and\n          `allow_negative_crop` is set to False, skip this image.\n    \"\"\"\n\n    def __init__(self,\n                 crop_size,\n                 allow_negative_crop=False,\n                 share_params=False,\n                 bbox_clip_border=True,\n                 check_id_match=True\n                 ):\n        assert crop_size[0] > 0 and crop_size[1] > 0\n        self.crop_size = crop_size\n        self.allow_negative_crop = allow_negative_crop\n        self.share_params = share_params\n        self.bbox_clip_border = bbox_clip_border\n        self.check_id_match = check_id_match\n        # The key correspondence from bboxes to labels and masks.\n        self.bbox2label = {\n            'gt_bboxes': ['gt_labels', 'gt_instance_ids'],\n            'gt_bboxes_ignore': ['gt_labels_ignore', 'gt_instance_ids_ignore']\n        }\n        self.bbox2mask = {\n            'gt_bboxes': 'gt_masks',\n            'gt_bboxes_ignore': 'gt_masks_ignore'\n        }\n\n    def get_offsets(self, img):\n        \"\"\"Random generate the offsets for cropping.\"\"\"\n        margin_h = max(img.shape[0] - self.crop_size[0], 0)\n        margin_w = max(img.shape[1] - self.crop_size[1], 0)\n        offset_h = np.random.randint(0, margin_h + 1)\n        offset_w = np.random.randint(0, margin_w + 1)\n        return offset_h, offset_w\n\n    def random_crop(self, results, offsets=None):\n        \"\"\"Call function to randomly crop images, bounding boxes, masks,\n        semantic segmentation maps.\n        Args:\n            results (dict): Result 
dict from loading pipeline.\n            offsets (tuple, optional): Pre-defined offsets for cropping.\n                Default to None.\n        Returns:\n            dict: Randomly cropped results, 'img_shape' key in result dict is\n            updated according to crop size.\n        \"\"\"\n\n        for key in results.get('img_fields', ['img']):\n            img = results[key]\n            if offsets is not None:\n                offset_h, offset_w = offsets\n            else:\n                offset_h, offset_w = self.get_offsets(img)\n            results['img_info']['crop_offsets'] = (offset_h, offset_w)\n            crop_y1, crop_y2 = offset_h, offset_h + self.crop_size[0]\n            crop_x1, crop_x2 = offset_w, offset_w + self.crop_size[1]\n\n            # crop the image\n            img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...]\n            img_shape = img.shape\n            results[key] = img\n        results['img_shape'] = img_shape\n\n        # crop bboxes accordingly and clip to the image boundary\n        for key in results.get('bbox_fields', []):\n            # e.g. gt_bboxes and gt_bboxes_ignore\n            bbox_offset = np.array([offset_w, offset_h, offset_w, offset_h],\n                                   dtype=np.float32)\n            bboxes = results[key] - bbox_offset\n            if self.bbox_clip_border:\n                bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1])\n                bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0])\n            valid_inds = (bboxes[:, 2] > bboxes[:, 0]) & (\n                    bboxes[:, 3] > bboxes[:, 1])\n            # If the crop does not contain any gt-bbox area and\n            # self.allow_negative_crop is False, skip this image.\n            if (key == 'gt_bboxes' and not valid_inds.any()\n                    and not self.allow_negative_crop):\n                return None\n            results[key] = bboxes[valid_inds, :]\n            # label fields. e.g. 
gt_labels and gt_labels_ignore\n            label_keys = self.bbox2label.get(key)\n            for label_key in label_keys:\n                if label_key in results:\n                    results[label_key] = results[label_key][valid_inds]\n\n            # mask fields, e.g. gt_masks and gt_masks_ignore\n            mask_key = self.bbox2mask.get(key)\n            if mask_key in results:\n                results[mask_key] = results[mask_key][\n                    valid_inds.nonzero()[0]].crop(\n                    np.asarray([crop_x1, crop_y1, crop_x2, crop_y2]))\n\n        # crop semantic seg\n        for key in results.get('seg_fields', []):\n            results[key] = results[key][crop_y1:crop_y2, crop_x1:crop_x2]\n        return results\n\n    def __call__(self, results):\n        \"\"\"Call function to sequentially randomly crop images, bounding boxes,\n        masks, semantic segmentation maps.\n        Args:\n            results (dict): Result dict from loading pipeline.\n        Returns:\n            dict: Randomly cropped results, 'img_shape' key in result dict is\n            updated according to crop size.\n        \"\"\"\n        if self.share_params:\n            offsets = self.get_offsets(results[0]['img'])\n        else:\n            offsets = None\n\n        outs = []\n        for _results in results:\n            _results = self.random_crop(_results, offsets)\n            if _results is None:\n                return None\n            outs.append(_results)\n\n        if len(outs) == 2 and self.check_id_match:\n            ref_result, result = outs[1], outs[0]\n            if self.check_match(ref_result, result):\n                return None\n        return outs\n\n    def check_match(self, ref_results, results):\n        ref_ids = ref_results['gt_instance_ids'].tolist()\n        gt_ids = results['gt_instance_ids'].tolist()\n        gt_pids = [ref_ids.index(i) if i in ref_ids else -1 for i in gt_ids]\n        nomatch = (np.array(gt_pids) == -1).all()\n  
      return nomatch\n\n\n@PIPELINES.register_module()\nclass SeqPhotoMetricDistortion(object):\n    \"\"\"Apply photometric distortion to every image of a sequence. Each\n    transformation is applied with a probability of 0.5. Random contrast is\n    applied either second or second to last.\n    1. random brightness\n    2. random contrast (if ``contrast_first``)\n    3. convert color from BGR to HSV\n    4. random saturation\n    5. random hue\n    6. convert color from HSV to BGR\n    7. random contrast (if not ``contrast_first``)\n    8. randomly swap channels\n    Args:\n        share_params (bool): If True, sample the distortion parameters once\n            and share them across all images. Defaults to True.\n        brightness_delta (int): delta of brightness.\n        contrast_range (tuple): range of contrast.\n        saturation_range (tuple): range of saturation.\n        hue_delta (int): delta of hue.\n    \"\"\"\n\n    def __init__(self,\n                 share_params=True,\n                 brightness_delta=32,\n                 contrast_range=(0.5, 1.5),\n                 saturation_range=(0.5, 1.5),\n                 hue_delta=18):\n        self.share_params = share_params\n        self.brightness_delta = brightness_delta\n        self.contrast_lower, self.contrast_upper = contrast_range\n        self.saturation_lower, self.saturation_upper = saturation_range\n        self.hue_delta = hue_delta\n\n    def get_params(self):\n        \"\"\"Generate parameters.\"\"\"\n        params = dict()\n        # delta\n        if np.random.randint(2):\n            params['delta'] = np.random.uniform(-self.brightness_delta,\n                                                self.brightness_delta)\n        else:\n            params['delta'] = None\n        # mode: 1 --> random contrast first, 0 --> random contrast last\n        mode = np.random.randint(2)\n        params['contrast_first'] = mode == 1\n        # alpha\n        if np.random.randint(2):\n            params['alpha'] = np.random.uniform(self.contrast_lower,\n                                                self.contrast_upper)\n        else:\n            params['alpha'] = None\n        # saturation\n        if 
np.random.randint(2):\n            params['saturation'] = np.random.uniform(self.saturation_lower,\n                                                     self.saturation_upper)\n        else:\n            params['saturation'] = None\n        # hue\n        if np.random.randint(2):\n            params['hue'] = np.random.uniform(-self.hue_delta, self.hue_delta)\n        else:\n            params['hue'] = None\n        # swap\n        if np.random.randint(2):\n            params['permutation'] = np.random.permutation(3)\n        else:\n            params['permutation'] = None\n        return params\n\n    def photo_metric_distortion(self, results, params=None):\n        \"\"\"Call function to perform photometric distortion on images.\n        Args:\n            results (dict): Result dict from loading pipeline.\n            params (dict, optional): Pre-defined parameters. Defaults to None.\n        Returns:\n            dict: Result dict with images distorted.\n        \"\"\"\n        if params is None:\n            params = self.get_params()\n        results['img_info']['color_jitter'] = params\n\n        if 'img_fields' in results:\n            assert results['img_fields'] == ['img'], \\\n                'Only single img_fields is allowed'\n        img = results['img']\n        assert img.dtype == np.float32, \\\n            'PhotoMetricDistortion needs the input image of dtype np.float32,' \\\n            ' please set \"to_float32=True\" in \"LoadImageFromFile\" pipeline'\n        # random brightness\n        if params['delta'] is not None:\n            img += params['delta']\n\n        # contrast_first is True --> do random contrast first\n        # contrast_first is False --> do random contrast last\n        if params['contrast_first']:\n            if params['alpha'] is not None:\n                img *= params['alpha']\n\n        # convert color from BGR to HSV\n        img = mmcv.bgr2hsv(img)\n\n        # random saturation\n        if params['saturation'] is not None:\n            img[..., 
1] *= params['saturation']\n\n        # random hue\n        if params['hue'] is not None:\n            img[..., 0] += params['hue']\n            img[..., 0][img[..., 0] > 360] -= 360\n            img[..., 0][img[..., 0] < 0] += 360\n\n        # convert color from HSV to BGR\n        img = mmcv.hsv2bgr(img)\n\n        # random contrast\n        if not params['contrast_first']:\n            if params['alpha'] is not None:\n                img *= params['alpha']\n\n        # randomly swap channels\n        if params['permutation'] is not None:\n            img = img[..., params['permutation']]\n\n        results['img'] = img\n        return results\n\n    def __call__(self, results):\n        \"\"\"Call function to perform photometric distortion on images.\n        Args:\n            results (dict): Result dict from loading pipeline.\n        Returns:\n            dict: Result dict with images distorted.\n        \"\"\"\n        if self.share_params:\n            params = self.get_params()\n        else:\n            params = None\n\n        outs = []\n        for _results in results:\n            _results = self.photo_metric_distortion(_results, params)\n            outs.append(_results)\n\n        return outs\n\n    def __repr__(self):\n        repr_str = self.__class__.__name__\n        repr_str += f'(\\nbrightness_delta={self.brightness_delta},\\n'\n        repr_str += 'contrast_range='\n        repr_str += f'{(self.contrast_lower, self.contrast_upper)},\\n'\n        repr_str += 'saturation_range='\n        repr_str += f'{(self.saturation_lower, self.saturation_upper)},\\n'\n        repr_str += f'hue_delta={self.hue_delta})'\n        return repr_str\n\n\n@PIPELINES.register_module()\nclass ResizeWithRef(object):\n    \"\"\"Resize images & bbox & mask.\n\n    This transform resizes the input image to some scale. Bboxes and masks are\n    then resized with the same scale factor. 
If the input dict contains the key\n    \"scale\", then the scale in the input dict is used, otherwise the specified\n    scale in the init method is used.\n\n    `img_scale` can either be a tuple (single-scale) or a list of tuples\n    (multi-scale). There are 3 multiscale modes:\n    - `ratio_range` is not None: randomly sample a ratio from the ratio range\n        and multiply it with the image scale.\n    - `ratio_range` is None and `multiscale_mode` == \"range\": randomly sample a\n        scale from a range.\n    - `ratio_range` is None and `multiscale_mode` == \"value\": randomly sample a\n        scale from multiple scales.\n\n    Args:\n        img_scale (tuple or list[tuple]): Image scales for resizing.\n        multiscale_mode (str): Either \"range\" or \"value\".\n        ratio_range (tuple[float]): (min_ratio, max_ratio)\n        keep_ratio (bool): Whether to keep the aspect ratio when resizing the\n            image.\n    \"\"\"\n\n    def __init__(self,\n                 img_scale=None,\n                 multiscale_mode='range',\n                 ratio_range=None,\n                 keep_ratio=True):\n        if img_scale is None:\n            self.img_scale = None\n        else:\n            if isinstance(img_scale, list):\n                self.img_scale = img_scale\n            else:\n                self.img_scale = [img_scale]\n            assert mmcv.is_list_of(self.img_scale, tuple)\n\n        if ratio_range is not None:\n            # mode 1: given a scale and a range of image ratios\n            assert len(self.img_scale) == 1\n        else:\n            # mode 2: given multiple scales or a range of scales\n            assert multiscale_mode in ['value', 'range']\n\n        self.multiscale_mode = multiscale_mode\n        self.ratio_range = ratio_range\n        self.keep_ratio = keep_ratio\n\n    @staticmethod\n    def random_select(img_scales):\n        assert mmcv.is_list_of(img_scales, tuple)\n        scale_idx = 
np.random.randint(len(img_scales))\n        img_scale = img_scales[scale_idx]\n        return img_scale, scale_idx\n\n    @staticmethod\n    def random_sample(img_scales):\n        assert mmcv.is_list_of(img_scales, tuple) and len(img_scales) == 2\n        img_scale_long = [max(s) for s in img_scales]\n        img_scale_short = [min(s) for s in img_scales]\n        long_edge = np.random.randint(\n            min(img_scale_long),\n            max(img_scale_long) + 1)\n        short_edge = np.random.randint(\n            min(img_scale_short),\n            max(img_scale_short) + 1)\n        img_scale = (long_edge, short_edge)\n        return img_scale, None\n\n    @staticmethod\n    def random_sample_ratio(img_scale, ratio_range):\n        assert isinstance(img_scale, tuple) and len(img_scale) == 2\n        min_ratio, max_ratio = ratio_range\n        assert min_ratio <= max_ratio\n        ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio\n        scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio)\n        return scale, None\n\n    def _random_scale(self, results):\n        if self.ratio_range is not None:\n            scale, scale_idx = self.random_sample_ratio(\n                self.img_scale[0], self.ratio_range)\n        elif len(self.img_scale) == 1:\n            scale, scale_idx = self.img_scale[0], 0\n        elif self.multiscale_mode == 'range':\n            scale, scale_idx = self.random_sample(self.img_scale)\n        elif self.multiscale_mode == 'value':\n            scale, scale_idx = self.random_select(self.img_scale)\n        else:\n            raise NotImplementedError\n\n        results['scale'] = scale\n        results['scale_idx'] = scale_idx\n\n    def _resize_img(self, results):\n        els = ['ref_img', 'img'] if 'ref_img' in results else ['img']\n        for el in els:\n            if self.keep_ratio:\n                img, scale_factor = mmcv.imrescale(\n                    results[el], results['scale'], 
return_scale=True)\n            else:\n                img, w_scale, h_scale = mmcv.imresize(\n                    results[el], results['scale'], return_scale=True)\n                scale_factor = np.array([w_scale, h_scale, w_scale, h_scale],\n                                        dtype=np.float32)\n            results[el] = img\n        results['img_shape'] = img.shape\n        results['pad_shape'] = img.shape  # in case that there is no padding\n        results['scale_factor'] = scale_factor\n        results['keep_ratio'] = self.keep_ratio\n\n    def _resize_bboxes(self, results):\n        els = ['ref_bbox_fields', 'bbox_fields'] if 'ref_bbox_fields' in results else ['bbox_fields']\n        for el in els:\n            img_shape = results['img_shape']\n            for key in results.get(el, []):\n                bboxes = results[key] * results['scale_factor']\n                bboxes[:, 0::2] = np.clip(\n                    bboxes[:, 0::2], 0, img_shape[1] - 1)\n                bboxes[:, 1::2] = np.clip(\n                    bboxes[:, 1::2], 0, img_shape[0] - 1)\n                results[key] = bboxes\n\n    def _resize_masks(self, results):\n        els = ['ref_mask_fields', 'mask_fields'] if 'ref_mask_fields' in results else ['mask_fields']\n        for el in els:\n            for key in results.get(el, []):\n                if results[key] is None:\n                    continue\n                if self.keep_ratio:\n                    masks = [\n                        mmcv.imrescale(\n                            mask, results['scale_factor'],\n                            interpolation='nearest')\n                        for mask in results[key]\n                    ]\n                else:\n                    mask_size = (results['img_shape'][1],\n                                 results['img_shape'][0])\n                    masks = [\n                        mmcv.imresize(mask, mask_size,\n                                      interpolation='nearest')\n     
                   for mask in results[key]\n                    ]\n                results[key] = masks\n\n    def __call__(self, results):\n        if 'scale' not in results:\n            self._random_scale(results)\n        self._resize_img(results)\n        self._resize_bboxes(results)\n        self._resize_masks(results)\n        # self._resize_semantic_seg(results)\n        return results\n\n    def __repr__(self):\n        repr_str = self.__class__.__name__\n        repr_str += ('(img_scale={}, multiscale_mode={}, ratio_range={}, '\n                     'keep_ratio={})').format(self.img_scale,\n                                              self.multiscale_mode,\n                                              self.ratio_range,\n                                              self.keep_ratio)\n        return repr_str\n\n\n@PIPELINES.register_module()\nclass RandomFlipWithRef(object):\n    \"\"\"Flip the image & bbox & mask.\n\n    If the input dict contains the key \"flip\", then the flag will be used,\n    otherwise it will be randomly decided by a ratio specified in the init\n    method.\n\n    Args:\n        flip_ratio (float, optional): The flipping probability.\n    \"\"\"\n\n    def __init__(self, flip_ratio=None):\n        self.flip_ratio = flip_ratio\n        if flip_ratio is not None:\n            assert flip_ratio >= 0 and flip_ratio <= 1\n\n    def bbox_flip(self, bboxes, img_shape):\n        \"\"\"Flip bboxes horizontally.\n\n        Args:\n            bboxes(ndarray): shape (..., 4*k)\n            img_shape(tuple): (height, width)\n        \"\"\"\n        assert bboxes.shape[-1] % 4 == 0\n        w = img_shape[1]\n        flipped = bboxes.copy()\n        flipped[..., 0::4] = w - bboxes[..., 2::4] - 1\n        flipped[..., 2::4] = w - bboxes[..., 0::4] - 1\n        return flipped\n\n    def __call__(self, results):\n        if 'flip' not in results:\n            flip = True if np.random.rand() < self.flip_ratio else False\n            results['flip'] 
= flip\n        if results['flip']:\n            # flip image\n            results['img'] = mmcv.imflip(results['img'])\n            if 'ref_img' in results:\n                results['ref_img'] = mmcv.imflip(results['ref_img'])\n            # flip bboxes\n            for key in results.get('bbox_fields', []):\n                results[key] = self.bbox_flip(results[key],\n                                              results['img_shape'])\n            for key in results.get('ref_bbox_fields', []):\n                results[key] = self.bbox_flip(results[key],\n                                              results['img_shape'])\n            # flip masks\n            for key in results.get('mask_fields', []):\n                results[key] = [mask[:, ::-1] for mask in results[key]]\n            for key in results.get('ref_mask_fields', []):\n                results[key] = [mask[:, ::-1] for mask in results[key]]\n        return results\n\n    def __repr__(self):\n        return self.__class__.__name__ + '(flip_ratio={})'.format(\n            self.flip_ratio)\n\n\n@PIPELINES.register_module()\nclass PadWithRef(object):\n    \"\"\"Pad the image & mask.\n\n    There are two padding modes: (1) pad to a fixed size and (2) pad to the\n    minimum size that is divisible by some number.\n\n    Args:\n        size (tuple, optional): Fixed padding size.\n        size_divisor (int, optional): The divisor of padded size.\n        pad_val (float, optional): Padding value, 0 by default.\n    \"\"\"\n\n    def __init__(self, size=None, size_divisor=None, pad_val=0):\n        self.size = size\n        self.size_divisor = size_divisor\n        self.pad_val = pad_val\n        # only one of size and size_divisor should be valid\n        assert size is not None or size_divisor is not None\n        assert size is None or size_divisor is None\n\n    def _pad_img(self, results):\n        els = ['ref_img', 'img'] if 'ref_img' in results else ['img']\n        for el in els:\n            if 
self.size is not None:\n                padded_img = mmcv.impad(\n                    results[el], shape=self.size, pad_val=self.pad_val)\n            elif self.size_divisor is not None:\n                padded_img = mmcv.impad_to_multiple(\n                    results[el], self.size_divisor, pad_val=self.pad_val)\n            results[el] = padded_img\n        results['pad_shape'] = padded_img.shape\n        results['pad_fixed_size'] = self.size\n        results['pad_size_divisor'] = self.size_divisor\n\n    def _pad_masks(self, results):\n        els = ['ref_mask_fields', 'mask_fields'] if 'ref_mask_fields' in results else ['mask_fields']\n        for el in els:\n            pad_shape = results['pad_shape'][:2]\n            for key in results.get(el, []):\n                padded_masks = [\n                    mmcv.impad(mask, shape=pad_shape, pad_val=self.pad_val)\n                    for mask in results[key]\n                ]\n                results[key] = np.stack(padded_masks, axis=0)\n\n    def __call__(self, results):\n        self._pad_img(results)\n        self._pad_masks(results)\n        return results\n\n    def __repr__(self):\n        repr_str = self.__class__.__name__\n        repr_str += '(size={}, size_divisor={}, pad_val={})'.format(\n            self.size, self.size_divisor, self.pad_val)\n        return repr_str\n\n\n@PIPELINES.register_module()\nclass NormalizeWithRef(object):\n    \"\"\"Normalize the image.\n\n    Args:\n        mean (sequence): Mean values of 3 channels.\n        std (sequence): Std values of 3 channels.\n        to_rgb (bool): Whether to convert the image from BGR to RGB,\n            default is True.\n    \"\"\"\n\n    def __init__(self, mean, std, to_rgb=True):\n        self.mean = np.array(mean, dtype=np.float32)\n        self.std = np.array(std, dtype=np.float32)\n        self.to_rgb = to_rgb\n\n    def __call__(self, results):\n        results['img'] = mmcv.imnormalize(\n            results['img'], self.mean, self.std, self.to_rgb)\n        if 'ref_img' in results:\n     
       results['ref_img'] = mmcv.imnormalize(\n                results['ref_img'], self.mean, self.std, self.to_rgb)\n        results['img_norm_cfg'] = dict(\n            mean=self.mean, std=self.std, to_rgb=self.to_rgb)\n        return results\n\n    def __repr__(self):\n        repr_str = self.__class__.__name__\n        repr_str += '(mean={}, std={}, to_rgb={})'.format(\n            self.mean, self.std, self.to_rgb)\n        return repr_str\n\n\n@PIPELINES.register_module()\nclass RandomCropWithRef(object):\n    \"\"\"Random crop the image & bboxes & masks.\n\n    Args:\n        crop_size (tuple): Expected size after cropping, (h, w).\n    \"\"\"\n\n    def __init__(self, crop_size):\n        self.crop_size = crop_size\n\n    def __call__(self, results):\n        img = results['img']\n\n        margin_h = max(img.shape[0] - self.crop_size[0], 0)\n        margin_w = max(img.shape[1] - self.crop_size[1], 0)\n        offset_h = np.random.randint(0, margin_h + 1)\n        offset_w = np.random.randint(0, margin_w + 1)\n        crop_y1, crop_y2 = offset_h, offset_h + self.crop_size[0]\n        crop_x1, crop_x2 = offset_w, offset_w + self.crop_size[1]\n\n        # crop the image\n        ori_shape = img.shape\n        img = img[crop_y1:crop_y2, crop_x1:crop_x2, :]\n        img_shape = img.shape\n        results['img'] = img\n        if 'ref_img' in results:\n            ref_img = results['ref_img']\n            ref_img = ref_img[crop_y1:crop_y2, crop_x1:crop_x2, :]\n            results['ref_img'] = ref_img\n        results['img_shape'] = img_shape\n        results['crop_coords'] = [crop_y1, crop_y2, crop_x1, crop_x2]\n\n        # crop bboxes accordingly and clip to the image boundary\n        els = ['ref_bbox_fields', 'bbox_fields'] if 'ref_bbox_fields' in results else ['bbox_fields']\n        for el in els:\n            for key in results.get(el, []):\n                bbox_offset = np.array(\n                    [offset_w, offset_h, offset_w, offset_h],\n              
      dtype=np.float32)\n                bboxes = results[key] - bbox_offset\n                bboxes[:, 0::2] = np.clip(\n                    bboxes[:, 0::2], 0, img_shape[1] - 1)\n                bboxes[:, 1::2] = np.clip(\n                    bboxes[:, 1::2], 0, img_shape[0] - 1)\n                results[key] = bboxes\n\n        # filter out the gt bboxes that are completely cropped\n        els = ['ref_bboxes', 'gt_bboxes'] if 'ref_bboxes' in results else ['gt_bboxes']\n        for el in els:\n            if el in results:\n                gt_bboxes = results[el]\n                valid_inds = (gt_bboxes[:, 2] > gt_bboxes[:, 0]) & (\n                        gt_bboxes[:, 3] > gt_bboxes[:, 1])\n                # if no gt bbox remains after cropping, just skip this image\n                if not np.any(valid_inds):\n                    return None\n                results[el] = gt_bboxes[valid_inds, :]\n                ell = el.replace('_bboxes', '_labels')\n                if ell in results:\n                    results[ell] = results[ell][valid_inds]\n                # filter gt_obj_ids just like gt_labels.\n                elo = el.replace('_bboxes', '_obj_ids')\n                if elo in results:\n                    results[elo] = results[elo][valid_inds]\n                # filter and crop the masks\n                elm = el.replace('_bboxes', '_masks')\n                if elm in results:\n                    valid_gt_masks = []\n                    for i in np.where(valid_inds)[0]:\n                        gt_mask = results[elm][i][\n                                  crop_y1:crop_y2, crop_x1:crop_x2]\n                        valid_gt_masks.append(gt_mask)\n                    results[elm] = valid_gt_masks\n\n        return results\n\n    def __repr__(self):\n        return self.__class__.__name__ + '(crop_size={})'.format(\n            self.crop_size)\n\n\n@PIPELINES.register_module()\nclass PadFutureMMDet:\n    \"\"\"Pad the image & masks & segmentation 
map.\n    There are two padding modes: (1) pad to a fixed size and (2) pad to the\n    minimum size that is divisible by some number.\n    Added keys are \"pad_shape\", \"pad_fixed_size\", \"pad_size_divisor\",\n    Args:\n        size (tuple, optional): Fixed padding size.\n        size_divisor (int, optional): The divisor of padded size.\n        pad_to_square (bool): Whether to pad the image into a square.\n            Currently only used for YOLOX. Default: False.\n        pad_val (dict, optional): A dict for padding value, the default\n            value is `dict(img=0, masks=0, seg=255)`.\n    \"\"\"\n\n    def __init__(self,\n                 size=None,\n                 size_divisor=None,\n                 pad_to_square=False,\n                 pad_val=dict(img=0, masks=0, seg=255)):\n        self.size = size\n        self.size_divisor = size_divisor\n        if isinstance(pad_val, float) or isinstance(pad_val, int):\n            warnings.warn(\n                'pad_val of float type is deprecated now, '\n                f'please use pad_val=dict(img={pad_val}, '\n                f'masks={pad_val}, seg=255) instead.', DeprecationWarning)\n            pad_val = dict(img=pad_val, masks=pad_val, seg=255)\n        assert isinstance(pad_val, dict)\n        self.pad_val = pad_val\n        self.pad_to_square = pad_to_square\n\n        if pad_to_square:\n            assert size is None and size_divisor is None, \\\n                'The size and size_divisor must be None ' \\\n                'when pad2square is True'\n        else:\n            assert size is not None or size_divisor is not None, \\\n                'only one of size and size_divisor should be valid'\n            assert size is None or size_divisor is None\n\n    def _pad_img(self, results):\n        \"\"\"Pad images according to ``self.size``.\"\"\"\n        pad_val = self.pad_val.get('img', 0)\n        for key in results.get('img_fields', ['img']):\n            if self.pad_to_square:\n             
   max_size = max(results[key].shape[:2])\n                self.size = (max_size, max_size)\n            if self.size is not None:\n                padded_img = mmcv.impad(\n                    results[key], shape=self.size, pad_val=pad_val)\n            elif self.size_divisor is not None:\n                padded_img = mmcv.impad_to_multiple(\n                    results[key], self.size_divisor, pad_val=pad_val)\n            results[key] = padded_img\n        results['pad_shape'] = padded_img.shape\n        results['pad_fixed_size'] = self.size\n        results['pad_size_divisor'] = self.size_divisor\n\n    def _pad_masks(self, results):\n        \"\"\"Pad masks according to ``results['pad_shape']``.\"\"\"\n        pad_shape = results['pad_shape'][:2]\n        pad_val = self.pad_val.get('masks', 0)\n        for key in results.get('mask_fields', []):\n            results[key] = results[key].pad(pad_shape, pad_val=pad_val)\n\n    def _pad_seg(self, results):\n        \"\"\"Pad semantic segmentation map according to\n        ``results['pad_shape']``.\"\"\"\n        pad_val = self.pad_val.get('seg', 255)\n        for key in results.get('seg_fields', []):\n            results[key] = mmcv.impad(\n                results[key], shape=results['pad_shape'][:2], pad_val=pad_val)\n\n    def __call__(self, results):\n        \"\"\"Call function to pad images, masks, semantic segmentation maps.\n        Args:\n            results (dict): Result dict from loading pipeline.\n        Returns:\n            dict: Updated result dict.\n        \"\"\"\n        self._pad_img(results)\n        self._pad_masks(results)\n        self._pad_seg(results)\n        return results\n\n    def __repr__(self):\n        repr_str = self.__class__.__name__\n        repr_str += f'(size={self.size}, '\n        repr_str += f'size_divisor={self.size_divisor}, '\n        repr_str += f'pad_to_square={self.pad_to_square}, '\n        repr_str += f'pad_val={self.pad_val})'\n        return 
repr_str\n\n\n@PIPELINES.register_module()\nclass KNetInsAdapter:\n    \"\"\"Adapter that converts Cityscapes-style thing class ids, which start\n    after the ``stuff_nums`` stuff classes (11 by default), to COCO-style\n    0-starting instance class ids.\n    \"\"\"\n\n    def __init__(self, stuff_nums=11):\n        self.stuff_nums = stuff_nums\n\n    def __call__(self, results):\n        \"\"\"Call function to modify gt_labels\n        Args:\n            results (dict): Result dict from loading pipeline.\n        Returns:\n            dict: Updated result dict.\n        \"\"\"\n        results['gt_labels'] -= self.stuff_nums\n        return results\n\n\n@PIPELINES.register_module()\nclass KNetInsAdapterCherryPick:\n    \"\"\"Variant of ``KNetInsAdapter`` for a cherry-picked subset of thing\n    classes. The i-th class id in ``cherry`` is first shifted down by i, then\n    ``stuff_nums`` is subtracted from all labels. With the defaults\n    (``cherry=(11, 13)``), 11 becomes 0 and 13 becomes 1.\n    \"\"\"\n\n    def __init__(self, stuff_nums=11, cherry=(11, 13)):\n        self.cherry = cherry\n        self.stuff_nums = stuff_nums\n\n    def __call__(self, results):\n        \"\"\"Call function to modify gt_labels\n        Args:\n            results (dict): Result dict from loading pipeline.\n        Returns:\n            dict: Updated result dict.\n        \"\"\"\n        bias = 0\n        for ch in self.cherry:\n            results['gt_labels'][results['gt_labels'] == ch] -= bias\n            bias += 1\n        results['gt_labels'] -= self.stuff_nums\n        return results\n"
  },
  {
    "path": "mmtrack/transform.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport numpy as np\nimport torch\nfrom mmdet.core import bbox2result\n\n\ndef outs2results(bboxes=None,\n                 labels=None,\n                 masks=None,\n                 ids=None,\n                 num_classes=None,\n                 **kwargs):\n    \"\"\"Convert tracking/detection results to per-class numpy arrays.\n    Args:\n        bboxes (torch.Tensor | np.ndarray): shape (n, 5)\n        labels (torch.Tensor | np.ndarray): shape (n, )\n        masks (torch.Tensor | np.ndarray): shape (n, h, w)\n        ids (torch.Tensor | np.ndarray): shape (n, )\n        num_classes (int): number of classes, not including the background\n            class\n    Returns:\n        dict[str : list(ndarray) | list[list[np.ndarray]]]: tracking/detection\n        results of each class. It may contain the following keys:\n        - bbox_results (list[np.ndarray]): Each list element holds the bboxes\n            of one category.\n        - mask_results (list[list[np.ndarray]]): The outer list is indexed by\n            category. Each inner list holds the masks belonging to\n            that category. 
Each mask has shape (h, w).\n    \"\"\"\n    assert labels is not None\n    assert num_classes is not None\n\n    results = dict()\n\n    if ids is not None:\n        # ids of -1 mark invalid (untracked) objects; drop them\n        valid_inds = ids > -1\n        ids = ids[valid_inds]\n        labels = labels[valid_inds]\n\n    if bboxes is not None:\n        if ids is not None:\n            bboxes = bboxes[valid_inds]\n            if bboxes.shape[0] == 0:\n                # each row is (id, x1, y1, x2, y2, score)\n                bbox_results = [\n                    np.zeros((0, 6), dtype=np.float32)\n                    for i in range(num_classes)\n                ]\n            else:\n                if isinstance(bboxes, torch.Tensor):\n                    bboxes = bboxes.cpu().numpy()\n                    labels = labels.cpu().numpy()\n                    ids = ids.cpu().numpy()\n                bbox_results = [\n                    np.concatenate(\n                        (ids[labels == i, None], bboxes[labels == i, :]),\n                        axis=1) for i in range(num_classes)\n                ]\n        else:\n            bbox_results = bbox2result(bboxes, labels, num_classes)\n        results['bbox_results'] = bbox_results\n\n    if masks is not None:\n        if ids is not None:\n            masks = masks[valid_inds]\n        if isinstance(masks, torch.Tensor):\n            masks = masks.detach().cpu().numpy()\n        masks_results = [[] for _ in range(num_classes)]\n        # iterate over the masks themselves, since bboxes may be None\n        for i in range(masks.shape[0]):\n            masks_results[labels[i]].append(masks[i])\n        results['mask_results'] = masks_results\n\n    return results"
  },
  {
    "path": "scripts/kitti_step_prepare.py",
    "content": "import os\nimport shutil\n\ntrain_seqs = [0, 1, 3, 4, 5, 9, 11, 12, 15, 17, 19, 20]\nval_seqs = [2, 6, 7, 8, 10, 13, 14, 16, 18]\ntest_seqs = list(range(29))\n\n# Paths to your downloaded KITTI-STEP dataset; edit these to match your setup.\n# Note: files are moved (not copied) from data_root into data_out.\ndata_root = os.path.expanduser('/data/data1/datasets/STEP/kitti/training/')\ndata_root_test = os.path.expanduser('/data/data1/datasets/STEP/kitti/testing/')\ndata_out = os.path.expanduser('/data/data1/datasets/STEP/kitti_out')\n\n\ndef build_panoptic(seq_id, input_dir, output_dir):\n    input_panoptic_dir = os.path.join(input_dir, '{:04d}'.format(seq_id))\n    print(\"Preparing seq id : {}\".format(seq_id))\n    panoptic_files = sorted(list(map(lambda x: str(x), os.listdir(input_panoptic_dir))))\n\n    print(\"Dst dir is {}\".format(output_dir))\n    if not os.path.exists(output_dir):\n        os.makedirs(output_dir)\n\n    for file in panoptic_files:\n        print(os.path.join(output_dir, '{:06d}_{}_panoptic.png'.format(seq_id, file.split('.')[0])))\n        shutil.move(os.path.join(input_panoptic_dir, file),\n                    os.path.join(output_dir, '{:06d}_{}_panoptic.png'.format(seq_id, file.split('.')[0])))\n\n\ndef build_img(seq_id, input_dir, output_dir):\n    input_img_dir = os.path.join(input_dir, '{:04d}'.format(seq_id))\n    print(\"Preparing seq id : {}\".format(seq_id))\n    img_files = sorted(list(map(lambda x: str(x), os.listdir(input_img_dir))))\n\n    print(\"Dst dir is {}\".format(output_dir))\n    if not os.path.exists(output_dir):\n        os.makedirs(output_dir)\n\n    for file in img_files:\n        print(os.path.join(output_dir, '{:06d}_{}_leftImg8bit.png'.format(seq_id, file.split('.')[0])))\n        shutil.move(os.path.join(input_img_dir, file),\n                    os.path.join(output_dir, '{:06d}_{}_leftImg8bit.png'.format(seq_id, file.split('.')[0])))\n\n\nif __name__ == '__main__':\n    for seq_id in train_seqs:\n        build_panoptic(seq_id, os.path.join(data_root, 'panoptic'), 
os.path.join(data_out, 'video_sequence', 'train'))\n\n    for seq_id in val_seqs:\n        build_panoptic(seq_id, os.path.join(data_root, 'panoptic'), os.path.join(data_out, 'video_sequence', 'val'))\n\n    for seq_id in train_seqs:\n        build_img(seq_id, os.path.join(data_root, 'image_02'), os.path.join(data_out, 'video_sequence', 'train'))\n\n    for seq_id in val_seqs:\n        build_img(seq_id, os.path.join(data_root, 'image_02'), os.path.join(data_out, 'video_sequence', 'val'))\n\n    for seq_id in test_seqs:\n        build_img(seq_id, os.path.join(data_root_test, 'image_02'), os.path.join(data_out, 'video_sequence', 'test'))"
  },
  {
    "path": "scripts/visualizer.py",
    "content": "import hashlib\nimport numpy as np\nimport cv2\n\ncity_labels = [\n    ('road', 0, (128, 64, 128)),\n    ('sidewalk', 1, (244, 35, 232)),\n    ('building', 2, (70, 70, 70)),\n    ('wall', 3, (102, 102, 156)),\n    ('fence', 4, (190, 153, 153)),\n    ('pole', 5, (153, 153, 153)),\n    ('traffic light', 6, (250, 170, 30)),\n    ('traffic sign', 7, (220, 220, 0)),\n    ('vegetation', 8, (107, 142, 35)),\n    ('terrain', 9, (152, 251, 152)),\n    ('sky', 10, (70, 130, 180)),\n    ('person', 11, (220, 20, 60)),\n    ('rider', 12, (255, 0, 0)),\n    ('car', 13, (0, 0, 142)),\n    ('truck', 14, (0, 0, 70)),\n    ('bus', 15, (0, 60, 100)),\n    ('train', 16, (0, 80, 100)),\n    ('motorcycle', 17, (0, 0, 230)),\n    ('bicycle', 18, (119, 11, 32)),\n    ('void', 19, (0, 0, 0)),\n    ('void', 255, (0, 0, 0))\n]\n\n\ndef sha256num(num):\n    hex = hashlib.sha256(str(num).encode('utf-8')).hexdigest()\n    hex = hex[-6:]\n    return int(hex, 16)\n\n\ndef id2rgb(id_map):\n    if isinstance(id_map, np.ndarray):\n        id_map_copy = id_map.copy()\n        rgb_shape = tuple(list(id_map.shape) + [3])\n        rgb_map = np.zeros(rgb_shape, dtype=np.uint8)\n        for i in range(3):\n            rgb_map[..., i] = id_map_copy % 256\n            id_map_copy //= 256\n        return rgb_map\n    color = []\n    for _ in range(3):\n        color.append(id_map % 256)\n        id_map //= 256\n    return color\n\n\ndef cityscapes_cat2rgb(cat_map):\n    color_map = np.zeros_like(cat_map).astype(np.uint8)\n    color_map = color_map[..., None].repeat(3, axis=-1)\n    for each_class in city_labels:\n        index = cat_map == each_class[1]\n        if index.any():\n            color_map[index] = each_class[2]\n    return color_map\n\n\ndef trackmap2rgb(track_map):\n    color_map = np.zeros_like(track_map).astype(np.uint8)\n    color_map = color_map[..., None].repeat(3, axis=-1)\n    for id_cur in np.unique(track_map):\n        if id_cur == 0:\n            continue\n        
color_map[track_map == id_cur] = id2rgb(sha256num(id_cur))\n    return color_map\n\n\ndef draw_bbox_on_img(vis_img, bboxes):\n    for index in range(bboxes.shape[0]):\n        cv2.rectangle(vis_img, (int(bboxes[index][0]), int(bboxes[index][1])),\n                      (int(bboxes[index][2]), int(bboxes[index][3])), (0, 0, 255), thickness=1)\n    return vis_img\n"
  },
  {
    "path": "swin/DetectRS.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport torch.nn as nn\nimport torch.utils.checkpoint as cp\nfrom mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init,\n                      kaiming_init)\nfrom mmcv.runner import Sequential, load_checkpoint\nfrom torch.nn.modules.batchnorm import _BatchNorm\n\nfrom mmdet.utils import get_root_logger\nfrom mmdet.models.builder import BACKBONES\nfrom mmdet.models.backbones.resnet import BasicBlock\nfrom mmdet.models.backbones.resnet import Bottleneck as _Bottleneck\nfrom mmdet.models.backbones.resnet import ResNet\n\n\nclass Bottleneck(_Bottleneck):\n    r\"\"\"Bottleneck for the ResNet backbone in `DetectoRS\n    <https://arxiv.org/pdf/2006.02334.pdf>`_.\n    This bottleneck allows the users to specify whether to use\n    SAC (Switchable Atrous Convolution) and RFP (Recursive Feature Pyramid).\n    Args:\n         inplanes (int): The number of input channels.\n         planes (int): The number of output channels before expansion.\n         rfp_inplanes (int, optional): The number of channels from RFP.\n             Default: None. If specified, an additional conv layer will be\n             added for ``rfp_feat``. Otherwise, the structure is the same as\n             base class.\n         sac (dict, optional): Dictionary to construct SAC. 
Default: None.\n         init_cfg (dict or list[dict], optional): Initialization config dict.\n            Default: None\n    \"\"\"\n    expansion = 4\n\n    def __init__(self,\n                 inplanes,\n                 planes,\n                 rfp_inplanes=None,\n                 sac=None,\n                 init_cfg=None,\n                 **kwargs):\n        super(Bottleneck, self).__init__(\n            inplanes, planes, init_cfg=init_cfg, **kwargs)\n\n        assert sac is None or isinstance(sac, dict)\n        self.sac = sac\n        self.with_sac = sac is not None\n        if self.with_sac:\n            self.conv2 = build_conv_layer(\n                self.sac,\n                planes,\n                planes,\n                kernel_size=3,\n                stride=self.conv2_stride,\n                padding=self.dilation,\n                dilation=self.dilation,\n                bias=False)\n\n        self.rfp_inplanes = rfp_inplanes\n        if self.rfp_inplanes:\n            self.rfp_conv = build_conv_layer(\n                None,\n                self.rfp_inplanes,\n                planes * self.expansion,\n                1,\n                stride=1,\n                bias=True)\n            # TODO : Is this a bug ?\n            if init_cfg is None:\n                self.init_cfg = dict(\n                    type='Constant', val=0, override=dict(name='rfp_conv'))\n\n    def rfp_forward(self, x, rfp_feat):\n        \"\"\"The forward function that also takes the RFP features as input.\"\"\"\n\n        def _inner_forward(x):\n            identity = x\n\n            out = self.conv1(x)\n            out = self.norm1(out)\n            out = self.relu(out)\n\n            if self.with_plugins:\n                out = self.forward_plugin(out, self.after_conv1_plugin_names)\n\n            out = self.conv2(out)\n            out = self.norm2(out)\n            out = self.relu(out)\n\n            if self.with_plugins:\n                out = self.forward_plugin(out, 
self.after_conv2_plugin_names)\n\n            out = self.conv3(out)\n            out = self.norm3(out)\n\n            if self.with_plugins:\n                out = self.forward_plugin(out, self.after_conv3_plugin_names)\n\n            if self.downsample is not None:\n                identity = self.downsample(x)\n\n            out += identity\n\n            return out\n\n        if self.with_cp and x.requires_grad:\n            out = cp.checkpoint(_inner_forward, x)\n        else:\n            out = _inner_forward(x)\n\n        if self.rfp_inplanes:\n            rfp_feat = self.rfp_conv(rfp_feat)\n            out = out + rfp_feat\n\n        out = self.relu(out)\n\n        return out\n\n\nclass ResLayer(Sequential):\n    \"\"\"ResLayer to build ResNet style backbone for RPF in detectoRS.\n    The difference between this module and base class is that we pass\n    ``rfp_inplanes`` to the first block.\n    Args:\n        block (nn.Module): block used to build ResLayer.\n        inplanes (int): inplanes of block.\n        planes (int): planes of block.\n        num_blocks (int): number of blocks.\n        stride (int): stride of the first block. Default: 1\n        avg_down (bool): Use AvgPool instead of stride conv when\n            downsampling in the bottleneck. Default: False\n        conv_cfg (dict): dictionary to construct and config conv layer.\n            Default: None\n        norm_cfg (dict): dictionary to construct and config norm layer.\n            Default: dict(type='BN')\n        downsample_first (bool): Downsample at the first block or last block.\n            False for Hourglass, True for ResNet. Default: True\n        rfp_inplanes (int, optional): The number of channels from RFP.\n            Default: None. If specified, an additional conv layer will be\n            added for ``rfp_feat``. 
Otherwise, the structure is the same as\n            base class.\n    \"\"\"\n\n    def __init__(self,\n                 block,\n                 inplanes,\n                 planes,\n                 num_blocks,\n                 stride=1,\n                 avg_down=False,\n                 conv_cfg=None,\n                 norm_cfg=dict(type='BN'),\n                 downsample_first=True,\n                 rfp_inplanes=None,\n                 **kwargs):\n        self.block = block\n        assert downsample_first, f'downsample_first={downsample_first} is ' \\\n                                 'not supported in DetectoRS'\n\n        downsample = None\n        if stride != 1 or inplanes != planes * block.expansion:\n            downsample = []\n            conv_stride = stride\n            if avg_down and stride != 1:\n                conv_stride = 1\n                downsample.append(\n                    nn.AvgPool2d(\n                        kernel_size=stride,\n                        stride=stride,\n                        ceil_mode=True,\n                        count_include_pad=False))\n            downsample.extend([\n                build_conv_layer(\n                    conv_cfg,\n                    inplanes,\n                    planes * block.expansion,\n                    kernel_size=1,\n                    stride=conv_stride,\n                    bias=False),\n                build_norm_layer(norm_cfg, planes * block.expansion)[1]\n            ])\n            downsample = nn.Sequential(*downsample)\n\n        layers = []\n        layers.append(\n            block(\n                inplanes=inplanes,\n                planes=planes,\n                stride=stride,\n                downsample=downsample,\n                conv_cfg=conv_cfg,\n                norm_cfg=norm_cfg,\n                rfp_inplanes=rfp_inplanes,\n                **kwargs))\n        inplanes = planes * block.expansion\n        for _ in range(1, num_blocks):\n            
layers.append(\n                block(\n                    inplanes=inplanes,\n                    planes=planes,\n                    stride=1,\n                    conv_cfg=conv_cfg,\n                    norm_cfg=norm_cfg,\n                    **kwargs))\n\n        super(ResLayer, self).__init__(*layers)\n\n\n@BACKBONES.register_module()\nclass DetectoRS_ResNet_Custom(ResNet):\n    \"\"\"ResNet backbone for DetectoRS.\n    Args:\n        sac (dict, optional): Dictionary to construct SAC (Switchable Atrous\n            Convolution). Default: None.\n        stage_with_sac (list): Which stage to use sac. Default: (False, False,\n            False, False).\n        rfp_inplanes (int, optional): The number of channels from RFP.\n            Default: None. If specified, an additional conv layer will be\n            added for ``rfp_feat``. Otherwise, the structure is the same as\n            base class.\n        output_img (bool): If ``True``, the input image will be inserted into\n            the starting position of output. 
Default: False.\n    \"\"\"\n\n    arch_settings = {\n        50: (Bottleneck, (3, 4, 6, 3)),\n        101: (Bottleneck, (3, 4, 23, 3)),\n        152: (Bottleneck, (3, 8, 36, 3))\n    }\n\n    def __init__(self,\n                 sac=None,\n                 stage_with_sac=(False, False, False, False),\n                 rfp_inplanes=None,\n                 output_img=False,\n                 pretrained=None,\n                 init_cfg=None,\n                 **kwargs):\n        assert not (init_cfg and pretrained), \\\n            'init_cfg and pretrained cannot be specified at the same time'\n        assert pretrained is None, \"pretrained is not supported anymore\"\n        self.sac = sac\n        self.stage_with_sac = stage_with_sac\n        self.rfp_inplanes = rfp_inplanes\n        self.output_img = output_img\n        super().__init__(init_cfg=init_cfg, **kwargs)\n\n        self.inplanes = self.stem_channels\n        self.res_layers = []\n        for i, num_blocks in enumerate(self.stage_blocks):\n            stride = self.strides[i]\n            dilation = self.dilations[i]\n            dcn = self.dcn if self.stage_with_dcn[i] else None\n            sac = self.sac if self.stage_with_sac[i] else None\n            if self.plugins is not None:\n                stage_plugins = self.make_stage_plugins(self.plugins, i)\n            else:\n                stage_plugins = None\n            planes = self.base_channels * 2 ** i\n            res_layer = self.make_res_layer(\n                block=self.block,\n                inplanes=self.inplanes,\n                planes=planes,\n                num_blocks=num_blocks,\n                stride=stride,\n                dilation=dilation,\n                style=self.style,\n                avg_down=self.avg_down,\n                with_cp=self.with_cp,\n                conv_cfg=self.conv_cfg,\n                norm_cfg=self.norm_cfg,\n                dcn=dcn,\n                sac=sac,\n                
rfp_inplanes=rfp_inplanes if i > 0 else None,\n                plugins=stage_plugins)\n            self.inplanes = planes * self.block.expansion\n            layer_name = f'layer{i + 1}'\n            self.add_module(layer_name, res_layer)\n            self.res_layers.append(layer_name)\n\n        self._freeze_stages()\n\n    # In order to be properly initialized by RFP\n    def init_weights(self):\n        # Calling this method will cause parameter initialization exception\n        # super(DetectoRS_ResNet, self).init_weights()\n        if self.init_cfg is not None:\n            super(ResNet, self).init_weights()\n        elif self.pretrained is None:\n            for m in self.modules():\n                if isinstance(m, nn.Conv2d):\n                    kaiming_init(m)\n                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):\n                    constant_init(m, 1)\n\n            if self.dcn is not None:\n                for m in self.modules():\n                    if isinstance(m, Bottleneck) and hasattr(\n                            m.conv2, 'conv_offset'):\n                        constant_init(m.conv2.conv_offset, 0)\n\n            if self.zero_init_residual:\n                for m in self.modules():\n                    if isinstance(m, Bottleneck):\n                        constant_init(m.norm3, 0)\n                    elif isinstance(m, BasicBlock):\n                        constant_init(m.norm2, 0)\n        else:\n            raise TypeError('pretrained must be a str or None')\n\n    def make_res_layer(self, **kwargs):\n        \"\"\"Pack all blocks in a stage into a ``ResLayer`` for DetectoRS.\"\"\"\n        return ResLayer(**kwargs)\n\n    def forward(self, x):\n        \"\"\"Forward function.\"\"\"\n        outs = list(super().forward(x))\n        if self.output_img:\n            outs.insert(0, x)\n        return tuple(outs)\n\n    def rfp_forward(self, x, rfp_feats):\n        \"\"\"Forward function for RFP.\"\"\"\n        if self.deep_stem:\n  
          x = self.stem(x)\n        else:\n            x = self.conv1(x)\n            x = self.norm1(x)\n            x = self.relu(x)\n        x = self.maxpool(x)\n        outs = []\n        for i, layer_name in enumerate(self.res_layers):\n            res_layer = getattr(self, layer_name)\n            rfp_feat = rfp_feats[i] if i > 0 else None\n            for layer in res_layer:\n                x = layer.rfp_forward(x, rfp_feat)\n            if i in self.out_indices:\n                outs.append(x)\n        return tuple(outs)\n"
  },
  {
    "path": "swin/ckpt_convert.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\n# This script consists of several convert functions which\n# can modify the weights of model in original repo to be\n# pre-trained weights.\n\nfrom collections import OrderedDict\n\nimport torch\n\n\ndef pvt_convert(ckpt):\n    new_ckpt = OrderedDict()\n    # Process the concat between q linear weights and kv linear weights\n    use_abs_pos_embed = False\n    use_conv_ffn = False\n    for k in ckpt.keys():\n        if k.startswith('pos_embed'):\n            use_abs_pos_embed = True\n        if k.find('dwconv') >= 0:\n            use_conv_ffn = True\n    for k, v in ckpt.items():\n        if k.startswith('head'):\n            continue\n        if k.startswith('norm.'):\n            continue\n        if k.startswith('cls_token'):\n            continue\n        if k.startswith('pos_embed'):\n            stage_i = int(k.replace('pos_embed', ''))\n            new_k = k.replace(f'pos_embed{stage_i}',\n                              f'layers.{stage_i - 1}.1.0.pos_embed')\n            if stage_i == 4 and v.size(1) == 50:  # 1 (cls token) + 7 * 7\n                new_v = v[:, 1:, :]  # remove cls token\n            else:\n                new_v = v\n        elif k.startswith('patch_embed'):\n            stage_i = int(k.split('.')[0].replace('patch_embed', ''))\n            new_k = k.replace(f'patch_embed{stage_i}',\n                              f'layers.{stage_i - 1}.0')\n            new_v = v\n            if 'proj.' in new_k:\n                new_k = new_k.replace('proj.', 'projection.')\n        elif k.startswith('block'):\n            stage_i = int(k.split('.')[0].replace('block', ''))\n            layer_i = int(k.split('.')[1])\n            new_layer_i = layer_i + use_abs_pos_embed\n            new_k = k.replace(f'block{stage_i}.{layer_i}',\n                              f'layers.{stage_i - 1}.1.{new_layer_i}')\n            new_v = v\n            if 'attn.q.' 
in new_k:\n                sub_item_k = k.replace('q.', 'kv.')\n                new_k = new_k.replace('q.', 'attn.in_proj_')\n                new_v = torch.cat([v, ckpt[sub_item_k]], dim=0)\n            elif 'attn.kv.' in new_k:\n                continue\n            elif 'attn.proj.' in new_k:\n                new_k = new_k.replace('proj.', 'attn.out_proj.')\n            elif 'attn.sr.' in new_k:\n                new_k = new_k.replace('sr.', 'sr.')\n            elif 'mlp.' in new_k:\n                string = f'{new_k}-'\n                new_k = new_k.replace('mlp.', 'ffn.layers.')\n                if 'fc1.weight' in new_k or 'fc2.weight' in new_k:\n                    new_v = v.reshape((*v.shape, 1, 1))\n                new_k = new_k.replace('fc1.', '0.')\n                new_k = new_k.replace('dwconv.dwconv.', '1.')\n                if use_conv_ffn:\n                    new_k = new_k.replace('fc2.', '4.')\n                else:\n                    new_k = new_k.replace('fc2.', '3.')\n                string += f'{new_k} {v.shape}-{new_v.shape}'\n        elif k.startswith('norm'):\n            stage_i = int(k[4])\n            new_k = k.replace(f'norm{stage_i}', f'layers.{stage_i - 1}.2')\n            new_v = v\n        else:\n            new_k = k\n            new_v = v\n        new_ckpt[new_k] = new_v\n\n    return new_ckpt\n\n\ndef swin_converter(ckpt):\n\n    new_ckpt = OrderedDict()\n\n    def correct_unfold_reduction_order(x):\n        out_channel, in_channel = x.shape\n        x = x.reshape(out_channel, 4, in_channel // 4)\n        x = x[:, [0, 2, 1, 3], :].transpose(1,\n                                            2).reshape(out_channel, in_channel)\n        return x\n\n    def correct_unfold_norm_order(x):\n        in_channel = x.shape[0]\n        x = x.reshape(4, in_channel // 4)\n        x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel)\n        return x\n\n    for k, v in ckpt.items():\n        if k.startswith('head'):\n            continue\n    
    elif k.startswith('layers'):\n            new_v = v\n            if 'attn.' in k:\n                new_k = k.replace('attn.', 'attn.w_msa.')\n            elif 'mlp.' in k:\n                if 'mlp.fc1.' in k:\n                    new_k = k.replace('mlp.fc1.', 'ffn.layers.0.0.')\n                elif 'mlp.fc2.' in k:\n                    new_k = k.replace('mlp.fc2.', 'ffn.layers.1.')\n                else:\n                    new_k = k.replace('mlp.', 'ffn.')\n            elif 'downsample' in k:\n                new_k = k\n                if 'reduction.' in k:\n                    new_v = correct_unfold_reduction_order(v)\n                elif 'norm.' in k:\n                    new_v = correct_unfold_norm_order(v)\n            else:\n                new_k = k\n            new_k = new_k.replace('layers', 'stages', 1)\n        elif k.startswith('patch_embed'):\n            new_v = v\n            if 'proj' in k:\n                new_k = k.replace('proj', 'projection')\n            else:\n                new_k = k\n        else:\n            new_v = v\n            new_k = k\n\n        new_ckpt[new_k] = new_v\n\n    return new_ckpt\n"
  },
  {
    "path": "swin/mix_transformer.py",
    "content": "# ---------------------------------------------------------------\n# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.\n#\n# This work is licensed under the NVIDIA Source Code License\n# ---------------------------------------------------------------\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom functools import partial\n\nfrom timm.models.layers import DropPath, to_2tuple, trunc_normal_\nfrom timm.models.registry import register_model\nfrom timm.models.vision_transformer import _cfg\nfrom mmdet.models.builder import BACKBONES\nfrom mmdet.utils import get_root_logger\nfrom mmdet.models.backbones.resnet import ResNet\nfrom mmcv.runner import load_checkpoint, BaseModule\nimport math\n\n\nclass Mlp(nn.Module):\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.dwconv = DWConv(hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n        elif isinstance(m, nn.Conv2d):\n            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n            fan_out //= m.groups\n            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))\n            if m.bias is not None:\n                m.bias.data.zero_()\n\n    def forward(self, x, H, W):\n        x = 
self.fc1(x)\n        x = self.dwconv(x, H, W)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\nclass Attention(nn.Module):\n    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):\n        super().__init__()\n        assert dim % num_heads == 0, f\"dim {dim} should be divided by num_heads {num_heads}.\"\n\n        self.dim = dim\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        self.scale = qk_scale or head_dim ** -0.5\n\n        self.q = nn.Linear(dim, dim, bias=qkv_bias)\n        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n        self.sr_ratio = sr_ratio\n        if sr_ratio > 1:\n            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)\n            self.norm = nn.LayerNorm(dim)\n\n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n        elif isinstance(m, nn.Conv2d):\n            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n            fan_out //= m.groups\n            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))\n            if m.bias is not None:\n                m.bias.data.zero_()\n\n    def forward(self, x, H, W):\n        B, N, C = x.shape\n        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)\n\n        if self.sr_ratio > 1:\n            x_ = x.permute(0, 2, 1).reshape(B, C, H, W)\n            x_ = 
self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)\n            x_ = self.norm(x_)\n            kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n        else:\n            kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n        k, v = kv[0], kv[1]\n\n        attn = (q @ k.transpose(-2, -1)) * self.scale\n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n\n        return x\n\n\nclass Block(nn.Module):\n\n    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1):\n        super().__init__()\n        self.norm1 = norm_layer(dim)\n        self.attn = Attention(\n            dim,\n            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,\n            attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)\n        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here\n        self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n\n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n        elif isinstance(m, nn.Conv2d):\n            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n            fan_out //= m.groups\n            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))\n            if m.bias is not None:\n                m.bias.data.zero_()\n\n    def forward(self, x, H, W):\n        x = x + self.drop_path(self.attn(self.norm1(x), H, W))\n        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))\n\n        return x\n\n\nclass OverlapPatchEmbed(nn.Module):\n    \"\"\" Image to Patch Embedding\n    \"\"\"\n\n    def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dims=768):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]\n        self.num_patches = self.H * self.W\n        self.proj = nn.Conv2d(in_chans, embed_dims, kernel_size=patch_size, stride=stride,\n                              padding=(patch_size[0] // 2, patch_size[1] // 2))\n        self.norm = nn.LayerNorm(embed_dims)\n\n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) 
and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n        elif isinstance(m, nn.Conv2d):\n            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n            fan_out //= m.groups\n            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))\n            if m.bias is not None:\n                m.bias.data.zero_()\n\n    def forward(self, x):\n        x = self.proj(x)\n        _, _, H, W = x.shape\n        x = x.flatten(2).transpose(1, 2)\n        x = self.norm(x)\n\n        return x, H, W\n\n\nclass MixVisionTransformer(BaseModule):\n    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],\n                 num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,\n                 attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,\n                 depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], init_cfg=None):\n        super().__init__(init_cfg=init_cfg)\n        self.num_classes = num_classes\n        self.depths = depths\n\n        # patch_embed\n        self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans,\n                                              embed_dims=embed_dims[0])\n        self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0],\n                                              embed_dims=embed_dims[1])\n        self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1],\n                                              embed_dims=embed_dims[2])\n        self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2],\n                                              
embed_dims=embed_dims[3])\n\n        # transformer encoder\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule\n        cur = 0\n        self.block1 = nn.ModuleList([Block(\n            dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,\n            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,\n            sr_ratio=sr_ratios[0])\n            for i in range(depths[0])])\n        self.norm1 = norm_layer(embed_dims[0])\n\n        cur += depths[0]\n        self.block2 = nn.ModuleList([Block(\n            dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,\n            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,\n            sr_ratio=sr_ratios[1])\n            for i in range(depths[1])])\n        self.norm2 = norm_layer(embed_dims[1])\n\n        cur += depths[1]\n        self.block3 = nn.ModuleList([Block(\n            dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,\n            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,\n            sr_ratio=sr_ratios[2])\n            for i in range(depths[2])])\n        self.norm3 = norm_layer(embed_dims[2])\n\n        cur += depths[2]\n        self.block4 = nn.ModuleList([Block(\n            dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,\n            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,\n            sr_ratio=sr_ratios[3])\n            for i in range(depths[3])])\n        self.norm4 = norm_layer(embed_dims[3])\n\n        # classification head\n        # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity()\n\n        
self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n        elif isinstance(m, nn.Conv2d):\n            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n            fan_out //= m.groups\n            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))\n            if m.bias is not None:\n                m.bias.data.zero_()\n\n    # def init_weights(self, pretrained=None):\n    #     if isinstance(pretrained, str):\n    #         logger = get_root_logger()\n    #         load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger)\n\n    def reset_drop_path(self, drop_path_rate):\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]\n        cur = 0\n        for i in range(self.depths[0]):\n            self.block1[i].drop_path.drop_prob = dpr[cur + i]\n\n        cur += self.depths[0]\n        for i in range(self.depths[1]):\n            self.block2[i].drop_path.drop_prob = dpr[cur + i]\n\n        cur += self.depths[1]\n        for i in range(self.depths[2]):\n            self.block3[i].drop_path.drop_prob = dpr[cur + i]\n\n        cur += self.depths[2]\n        for i in range(self.depths[3]):\n            self.block4[i].drop_path.drop_prob = dpr[cur + i]\n\n    def freeze_patch_emb(self):\n        self.patch_embed1.requires_grad = False\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'}  # has pos_embed may be better\n\n    def get_classifier(self):\n        return self.head\n\n    def reset_classifier(self, num_classes, global_pool=''):\n        self.num_classes = 
num_classes\n        self.head = nn.Linear(self.embed_dims, num_classes) if num_classes > 0 else nn.Identity()\n\n    def forward_features(self, x):\n        B = x.shape[0]\n        outs = []\n\n        # stage 1\n        x, H, W = self.patch_embed1(x)\n        for i, blk in enumerate(self.block1):\n            x = blk(x, H, W)\n        x = self.norm1(x)\n        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()\n        outs.append(x)\n\n        # stage 2\n        x, H, W = self.patch_embed2(x)\n        for i, blk in enumerate(self.block2):\n            x = blk(x, H, W)\n        x = self.norm2(x)\n        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()\n        outs.append(x)\n\n        # stage 3\n        x, H, W = self.patch_embed3(x)\n        for i, blk in enumerate(self.block3):\n            x = blk(x, H, W)\n        x = self.norm3(x)\n        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()\n        outs.append(x)\n\n        # stage 4\n        x, H, W = self.patch_embed4(x)\n        for i, blk in enumerate(self.block4):\n            x = blk(x, H, W)\n        x = self.norm4(x)\n        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()\n        outs.append(x)\n\n        return outs\n\n    def forward(self, x):\n        x = self.forward_features(x)\n        # x = self.head(x)\n\n        return x\n\n\nclass DWConv(nn.Module):\n    def __init__(self, dim=768):\n        super(DWConv, self).__init__()\n        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)\n\n    def forward(self, x, H, W):\n        B, N, C = x.shape\n        x = x.transpose(1, 2).view(B, C, H, W)\n        x = self.dwconv(x)\n        x = x.flatten(2).transpose(1, 2)\n\n        return x\n\n\n@BACKBONES.register_module()\nclass mit_b0(MixVisionTransformer):\n    def __init__(self, **kwargs):\n        super(mit_b0, self).__init__(\n            patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],\n      
      qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],\n            drop_rate=0.0, drop_path_rate=0.1, **kwargs)\n\n\n@BACKBONES.register_module()\nclass mit_b1(MixVisionTransformer):\n    def __init__(self, **kwargs):\n        super(mit_b1, self).__init__(\n            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],\n            qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],\n            drop_rate=0.0, drop_path_rate=0.1, **kwargs)\n\n\n@BACKBONES.register_module()\nclass mit_b2(MixVisionTransformer):\n    def __init__(self, **kwargs):\n        super(mit_b2, self).__init__(\n            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],\n            qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1],\n            drop_rate=0.0, drop_path_rate=0.1, **kwargs)\n\n\n@BACKBONES.register_module()\nclass mit_b3(MixVisionTransformer):\n    def __init__(self, **kwargs):\n        super(mit_b3, self).__init__(\n            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],\n            qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],\n            drop_rate=0.0, drop_path_rate=0.1, **kwargs)\n\n\n@BACKBONES.register_module()\nclass mit_b4(MixVisionTransformer):\n    def __init__(self, **kwargs):\n        super(mit_b4, self).__init__(\n            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],\n            qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1],\n            drop_rate=0.0, drop_path_rate=0.1, **kwargs)\n\n\n@BACKBONES.register_module()\nclass mit_b5(MixVisionTransformer):\n    def __init__(self, **kwargs):\n        
super(mit_b5, self).__init__(\n            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],\n            qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1],\n            drop_rate=0.0, drop_path_rate=0.1, **kwargs)\n\n\n@BACKBONES.register_module()\nclass ResNetV1c(ResNet):\n    r\"\"\"ResNetV1c variant described in `Bag of Tricks\n    <https://arxiv.org/pdf/1812.01187.pdf>`_.\n\n    Compared with default ResNet(ResNetV1b), ResNetV1c replaces the 7x7 conv in\n    the input stem with three 3x3 convs. Unlike ResNetV1d, the downsampling\n    shortcut keeps its strided 1x1 conv (``avg_down=False``).\n    \"\"\"\n\n    def __init__(self, **kwargs):\n        super(ResNetV1c, self).__init__(\n            deep_stem=True, avg_down=False, **kwargs)"
  },
  {
    "path": "swin/swin_checkpoint.py",
    "content": "# Copyright (c) Open-MMLab. All rights reserved.\nimport io\nimport os\nimport os.path as osp\nimport pkgutil\nimport time\nimport warnings\nfrom collections import OrderedDict\nfrom importlib import import_module\nfrom tempfile import TemporaryDirectory\n\nimport mmcv\nimport torch\nimport torchvision\nfrom mmcv.fileio import FileClient\nfrom mmcv.fileio import load as load_file\nfrom mmcv.parallel import is_module_wrapper\nfrom mmcv.runner import get_dist_info\nfrom mmcv.utils import mkdir_or_exist\nfrom torch.nn import functional as F\nfrom torch.optim import Optimizer\nfrom torch.utils import model_zoo\n\nENV_MMCV_HOME = 'MMCV_HOME'\nENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'\nDEFAULT_CACHE_DIR = '~/.cache'\n\n\ndef _get_mmcv_home():\n    mmcv_home = os.path.expanduser(\n        os.getenv(\n            ENV_MMCV_HOME,\n            os.path.join(\n                os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmcv')))\n\n    mkdir_or_exist(mmcv_home)\n    return mmcv_home\n\n\ndef load_state_dict(module, state_dict, strict=False, logger=None):\n    \"\"\"Load state_dict to a module.\n\n    This method is modified from :meth:`torch.nn.Module.load_state_dict`.\n    Default value for ``strict`` is set to ``False`` and the message for\n    param mismatch will be shown even if strict is False.\n    Args:\n        module (Module): Module that receives the state_dict.\n        state_dict (OrderedDict): Weights.\n        strict (bool): whether to strictly enforce that the keys\n            in :attr:`state_dict` match the keys returned by this module's\n            :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.\n        logger (:obj:`logging.Logger`, optional): Logger to log the error\n            message. 
If not specified, print function will be used.\n    \"\"\"\n    unexpected_keys = []\n    all_missing_keys = []\n    err_msg = []\n\n    metadata = getattr(state_dict, '_metadata', None)\n    state_dict = state_dict.copy()\n    if metadata is not None:\n        state_dict._metadata = metadata\n\n    # use _load_from_state_dict to enable checkpoint version control\n    def load(module, prefix=''):\n        # recursively check parallel module in case that the model has a\n        # complicated structure, e.g., nn.Module(nn.Module(DDP))\n        if is_module_wrapper(module):\n            module = module.module\n        local_metadata = {} if metadata is None else metadata.get(\n            prefix[:-1], {})\n        module._load_from_state_dict(state_dict, prefix, local_metadata, True,\n                                     all_missing_keys, unexpected_keys,\n                                     err_msg)\n        for name, child in module._modules.items():\n            if child is not None:\n                load(child, prefix + name + '.')\n\n    load(module)\n    load = None  # break load->load reference cycle\n\n    # ignore \"num_batches_tracked\" of BN layers\n    missing_keys = [\n        key for key in all_missing_keys if 'num_batches_tracked' not in key\n    ]\n\n    if unexpected_keys:\n        err_msg.append('unexpected key in source '\n                       f'state_dict: {\", \".join(unexpected_keys)}\\n')\n    if missing_keys:\n        err_msg.append(\n            f'missing keys in source state_dict: {\", \".join(missing_keys)}\\n')\n\n    rank, _ = get_dist_info()\n    if len(err_msg) > 0 and rank == 0:\n        err_msg.insert(\n            0, 'The model and loaded state dict do not match exactly\\n')\n        err_msg = '\\n'.join(err_msg)\n        if strict:\n            raise RuntimeError(err_msg)\n        elif logger is not None:\n            logger.warning(err_msg)\n        else:\n            print(err_msg)\n\n\ndef load_url_dist(url, model_dir=None):\n 
   \"\"\"In distributed setting, this function only download checkpoint at local\n    rank 0.\"\"\"\n    rank, world_size = get_dist_info()\n    rank = int(os.environ.get('LOCAL_RANK', rank))\n    if rank == 0:\n        checkpoint = model_zoo.load_url(url, model_dir=model_dir)\n    if world_size > 1:\n        torch.distributed.barrier()\n        if rank > 0:\n            checkpoint = model_zoo.load_url(url, model_dir=model_dir)\n    return checkpoint\n\n\ndef load_pavimodel_dist(model_path, map_location=None):\n    \"\"\"In distributed setting, this function only download checkpoint at local\n    rank 0.\"\"\"\n    try:\n        from pavi import modelcloud\n    except ImportError:\n        raise ImportError(\n            'Please install pavi to load checkpoint from modelcloud.')\n    rank, world_size = get_dist_info()\n    rank = int(os.environ.get('LOCAL_RANK', rank))\n    if rank == 0:\n        model = modelcloud.get(model_path)\n        with TemporaryDirectory() as tmp_dir:\n            downloaded_file = osp.join(tmp_dir, model.name)\n            model.download(downloaded_file)\n            checkpoint = torch.load(downloaded_file, map_location=map_location)\n    if world_size > 1:\n        torch.distributed.barrier()\n        if rank > 0:\n            model = modelcloud.get(model_path)\n            with TemporaryDirectory() as tmp_dir:\n                downloaded_file = osp.join(tmp_dir, model.name)\n                model.download(downloaded_file)\n                checkpoint = torch.load(\n                    downloaded_file, map_location=map_location)\n    return checkpoint\n\n\ndef load_fileclient_dist(filename, backend, map_location):\n    \"\"\"In distributed setting, this function only download checkpoint at local\n    rank 0.\"\"\"\n    rank, world_size = get_dist_info()\n    rank = int(os.environ.get('LOCAL_RANK', rank))\n    allowed_backends = ['ceph']\n    if backend not in allowed_backends:\n        raise ValueError(f'Load from Backend {backend} is not 
supported.')\n    if rank == 0:\n        fileclient = FileClient(backend=backend)\n        buffer = io.BytesIO(fileclient.get(filename))\n        checkpoint = torch.load(buffer, map_location=map_location)\n    if world_size > 1:\n        torch.distributed.barrier()\n        if rank > 0:\n            fileclient = FileClient(backend=backend)\n            buffer = io.BytesIO(fileclient.get(filename))\n            checkpoint = torch.load(buffer, map_location=map_location)\n    return checkpoint\n\n\ndef get_torchvision_models():\n    model_urls = dict()\n    for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):\n        if ispkg:\n            continue\n        _zoo = import_module(f'torchvision.models.{name}')\n        if hasattr(_zoo, 'model_urls'):\n            _urls = getattr(_zoo, 'model_urls')\n            model_urls.update(_urls)\n    return model_urls\n\n\ndef get_external_models():\n    mmcv_home = _get_mmcv_home()\n    default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json')\n    default_urls = load_file(default_json_path)\n    assert isinstance(default_urls, dict)\n    external_json_path = osp.join(mmcv_home, 'open_mmlab.json')\n    if osp.exists(external_json_path):\n        external_urls = load_file(external_json_path)\n        assert isinstance(external_urls, dict)\n        default_urls.update(external_urls)\n\n    return default_urls\n\n\ndef get_mmcls_models():\n    mmcls_json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json')\n    mmcls_urls = load_file(mmcls_json_path)\n\n    return mmcls_urls\n\n\ndef get_deprecated_model_names():\n    deprecate_json_path = osp.join(mmcv.__path__[0],\n                                   'model_zoo/deprecated.json')\n    deprecate_urls = load_file(deprecate_json_path)\n    assert isinstance(deprecate_urls, dict)\n\n    return deprecate_urls\n\n\ndef _process_mmcls_checkpoint(checkpoint):\n    state_dict = checkpoint['state_dict']\n    new_state_dict = OrderedDict()\n    for 
k, v in state_dict.items():\n        if k.startswith('backbone.'):\n            new_state_dict[k[9:]] = v\n    new_checkpoint = dict(state_dict=new_state_dict)\n\n    return new_checkpoint\n\n\ndef _load_checkpoint(filename, map_location=None):\n    \"\"\"Load checkpoint from somewhere (modelzoo, file, url).\n\n    Args:\n        filename (str): Accept local filepath, URL, ``torchvision://xxx``,\n            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for\n            details.\n        map_location (str | None): Same as :func:`torch.load`. Default: None.\n    Returns:\n        dict | OrderedDict: The loaded checkpoint. It can be either an\n            OrderedDict storing model weights or a dict containing other\n            information, which depends on the checkpoint.\n    \"\"\"\n    if filename.startswith('modelzoo://'):\n        warnings.warn('The URL scheme of \"modelzoo://\" is deprecated, please '\n                      'use \"torchvision://\" instead')\n        model_urls = get_torchvision_models()\n        model_name = filename[11:]\n        checkpoint = load_url_dist(model_urls[model_name])\n    elif filename.startswith('torchvision://'):\n        model_urls = get_torchvision_models()\n        model_name = filename[14:]\n        checkpoint = load_url_dist(model_urls[model_name])\n    elif filename.startswith('open-mmlab://'):\n        model_urls = get_external_models()\n        model_name = filename[13:]\n        deprecated_urls = get_deprecated_model_names()\n        if model_name in deprecated_urls:\n            warnings.warn(f'open-mmlab://{model_name} is deprecated in favor '\n                          f'of open-mmlab://{deprecated_urls[model_name]}')\n            model_name = deprecated_urls[model_name]\n        model_url = model_urls[model_name]\n        # check if is url\n        if model_url.startswith(('http://', 'https://')):\n            checkpoint = load_url_dist(model_url)\n        else:\n            filename = 
osp.join(_get_mmcv_home(), model_url)\n            if not osp.isfile(filename):\n                raise IOError(f'{filename} is not a checkpoint file')\n            checkpoint = torch.load(filename, map_location=map_location)\n    elif filename.startswith('mmcls://'):\n        model_urls = get_mmcls_models()\n        model_name = filename[8:]\n        checkpoint = load_url_dist(model_urls[model_name])\n        checkpoint = _process_mmcls_checkpoint(checkpoint)\n    elif filename.startswith(('http://', 'https://')):\n        checkpoint = load_url_dist(filename)\n    elif filename.startswith('pavi://'):\n        model_path = filename[7:]\n        checkpoint = load_pavimodel_dist(model_path, map_location=map_location)\n    elif filename.startswith('s3://'):\n        checkpoint = load_fileclient_dist(\n            filename, backend='ceph', map_location=map_location)\n    else:\n        if not osp.isfile(filename):\n            raise IOError(f'{filename} is not a checkpoint file')\n        checkpoint = torch.load(filename, map_location=map_location)\n    return checkpoint\n\n\ndef load_checkpoint(model,\n                    filename,\n                    map_location='cpu',\n                    strict=False,\n                    logger=None):\n    \"\"\"Load checkpoint from a file or URI.\n\n    Args:\n        model (Module): Module to load checkpoint.\n        filename (str): Accept local filepath, URL, ``torchvision://xxx``,\n            ``open-mmlab://xxx``. 
Please refer to ``docs/model_zoo.md`` for\n            details.\n        map_location (str): Same as :func:`torch.load`.\n        strict (bool): Whether to allow different params for the model and\n            checkpoint.\n        logger (:mod:`logging.Logger` or None): The logger for error message.\n    Returns:\n        dict or OrderedDict: The loaded checkpoint.\n    \"\"\"\n    checkpoint = _load_checkpoint(filename, map_location)\n    # OrderedDict is a subclass of dict\n    if not isinstance(checkpoint, dict):\n        raise RuntimeError(\n            f'No state_dict found in checkpoint file {filename}')\n    # get state_dict from checkpoint\n    if 'state_dict' in checkpoint:\n        state_dict = checkpoint['state_dict']\n    elif 'model' in checkpoint:\n        state_dict = checkpoint['model']\n    else:\n        state_dict = checkpoint\n    # strip prefix of state_dict\n    if list(state_dict.keys())[0].startswith('module.'):\n        state_dict = {k[7:]: v for k, v in state_dict.items()}\n\n    # reshape absolute position embedding\n    if state_dict.get('absolute_pos_embed') is not None:\n        absolute_pos_embed = state_dict['absolute_pos_embed']\n        N1, L, C1 = absolute_pos_embed.size()\n        N2, C2, H, W = model.absolute_pos_embed.size()\n        if N1 != N2 or C1 != C2 or L != H * W:\n            logger.warning('Error in loading absolute_pos_embed, pass')\n        else:\n            state_dict['absolute_pos_embed'] = absolute_pos_embed.view(\n                N2, H, W, C2).permute(0, 3, 1, 2)\n\n    # interpolate position bias table if needed\n    relative_position_bias_table_keys = [\n        k for k in state_dict.keys() if 'relative_position_bias_table' in k\n    ]\n    for table_key in relative_position_bias_table_keys:\n        table_pretrained = state_dict[table_key]\n        table_current = model.state_dict()[table_key]\n        L1, nH1 = table_pretrained.size()\n        L2, nH2 = table_current.size()\n        if nH1 != nH2:\n         
   logger.warning(f'Error in loading {table_key}, pass')\n        else:\n            if L1 != L2:\n                S1 = int(L1**0.5)\n                S2 = int(L2**0.5)\n                table_pretrained_resized = F.interpolate(\n                    table_pretrained.permute(1, 0).view(1, nH1, S1, S1),\n                    size=(S2, S2),\n                    mode='bicubic')\n                state_dict[table_key] = table_pretrained_resized.view(\n                    nH2, L2).permute(1, 0)\n\n    # load state_dict\n    load_state_dict(model, state_dict, strict, logger)\n    return checkpoint\n\n\ndef weights_to_cpu(state_dict):\n    \"\"\"Copy a model state_dict to cpu.\n\n    Args:\n        state_dict (OrderedDict): Model weights on GPU.\n    Returns:\n        OrderedDict: Model weights on GPU.\n    \"\"\"\n    state_dict_cpu = OrderedDict()\n    for key, val in state_dict.items():\n        state_dict_cpu[key] = val.cpu()\n    return state_dict_cpu\n\n\ndef _save_to_state_dict(module, destination, prefix, keep_vars):\n    \"\"\"Saves module state to `destination` dictionary.\n\n    This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.\n    Args:\n        module (nn.Module): The module to generate state_dict.\n        destination (dict): A dict where state will be stored.\n        prefix (str): The prefix for parameters and buffers used in this\n            module.\n    \"\"\"\n    for name, param in module._parameters.items():\n        if param is not None:\n            destination[prefix + name] = param if keep_vars else param.detach()\n    for name, buf in module._buffers.items():\n        # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d\n        if buf is not None:\n            destination[prefix + name] = buf if keep_vars else buf.detach()\n\n\ndef get_state_dict(module, destination=None, prefix='', keep_vars=False):\n    \"\"\"Returns a dictionary containing a whole state of the module.\n\n    Both parameters and persistent 
buffers (e.g. running averages) are\n    included. Keys are corresponding parameter and buffer names.\n    This method is modified from :meth:`torch.nn.Module.state_dict` to\n    recursively check parallel module in case that the model has a complicated\n    structure, e.g., nn.Module(nn.Module(DDP)).\n    Args:\n        module (nn.Module): The module to generate state_dict.\n        destination (OrderedDict): Returned dict for the state of the\n            module.\n        prefix (str): Prefix of the key.\n        keep_vars (bool): Whether to keep the variable property of the\n            parameters. Default: False.\n    Returns:\n        dict: A dictionary containing a whole state of the module.\n    \"\"\"\n    # recursively check parallel module in case that the model has a\n    # complicated structure, e.g., nn.Module(nn.Module(DDP))\n    if is_module_wrapper(module):\n        module = module.module\n\n    # below is the same as torch.nn.Module.state_dict()\n    if destination is None:\n        destination = OrderedDict()\n        destination._metadata = OrderedDict()\n    destination._metadata[prefix[:-1]] = local_metadata = dict(\n        version=module._version)\n    _save_to_state_dict(module, destination, prefix, keep_vars)\n    for name, child in module._modules.items():\n        if child is not None:\n            get_state_dict(\n                child, destination, prefix + name + '.', keep_vars=keep_vars)\n    for hook in module._state_dict_hooks.values():\n        hook_result = hook(module, destination, prefix, local_metadata)\n        if hook_result is not None:\n            destination = hook_result\n    return destination\n\n\ndef save_checkpoint(model, filename, optimizer=None, meta=None):\n    \"\"\"Save checkpoint to file.\n\n    The checkpoint will have 3 fields: ``meta``, ``state_dict`` and\n    ``optimizer``. 
By default ``meta`` will contain version and time info.\n    Args:\n        model (Module): Module whose params are to be saved.\n        filename (str): Checkpoint filename.\n        optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.\n        meta (dict, optional): Metadata to be saved in checkpoint.\n    \"\"\"\n    if meta is None:\n        meta = {}\n    elif not isinstance(meta, dict):\n        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')\n    meta.update(mmcv_version=mmcv.__version__, time=time.asctime())\n\n    if is_module_wrapper(model):\n        model = model.module\n\n    if hasattr(model, 'CLASSES') and model.CLASSES is not None:\n        # save class name to the meta\n        meta.update(CLASSES=model.CLASSES)\n\n    checkpoint = {\n        'meta': meta,\n        'state_dict': weights_to_cpu(get_state_dict(model))\n    }\n    # save optimizer state dict in the checkpoint\n    if isinstance(optimizer, Optimizer):\n        checkpoint['optimizer'] = optimizer.state_dict()\n    elif isinstance(optimizer, dict):\n        checkpoint['optimizer'] = {}\n        for name, optim in optimizer.items():\n            checkpoint['optimizer'][name] = optim.state_dict()\n\n    if filename.startswith('pavi://'):\n        try:\n            from pavi import modelcloud\n            from pavi.exception import NodeNotFoundError\n        except ImportError:\n            raise ImportError(\n                'Please install pavi to load checkpoint from modelcloud.')\n        model_path = filename[7:]\n        root = modelcloud.Folder()\n        model_dir, model_name = osp.split(model_path)\n        try:\n            model = modelcloud.get(model_dir)\n        except NodeNotFoundError:\n            model = root.create_training_model(model_dir)\n        with TemporaryDirectory() as tmp_dir:\n            checkpoint_file = osp.join(tmp_dir, model_name)\n            with open(checkpoint_file, 'wb') as f:\n                torch.save(checkpoint, f)\n 
               f.flush()\n            model.create_file(checkpoint_file, name=model_name)\n    else:\n        mmcv.mkdir_or_exist(osp.dirname(filename))\n        # immediately flush buffer\n        with open(filename, 'wb') as f:\n            torch.save(checkpoint, f)\n            f.flush()\n"
  },
  {
    "path": "swin/swin_transformer.py",
    "content": "# --------------------------------------------------------\n# Swin Transformer\n# Copyright (c) 2021 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Ze Liu, Yutong Lin, Yixuan Wei\n# --------------------------------------------------------\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.utils.checkpoint as checkpoint\nfrom mmdet.models.builder import BACKBONES\nfrom mmdet.utils import get_root_logger\nfrom timm.models.layers import DropPath, to_2tuple, trunc_normal_\n\nfrom .swin_checkpoint import load_checkpoint\n\n\nclass Mlp(nn.Module):\n    \"\"\"Multilayer perceptron.\"\"\"\n\n    def __init__(self,\n                 in_features,\n                 hidden_features=None,\n                 out_features=None,\n                 act_layer=nn.GELU,\n                 drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\ndef window_partition(x, window_size):\n    \"\"\"\n    Args:\n        x: (B, H, W, C)\n        window_size (int): window size\n    Returns:\n        windows: (num_windows*B, window_size, window_size, C)\n    \"\"\"\n    B, H, W, C = x.shape\n    x = x.view(B, H // window_size, window_size, W // window_size, window_size,\n               C)\n    windows = x.permute(0, 1, 3, 2, 4,\n                        5).contiguous().view(-1, window_size, window_size, C)\n    return windows\n\n\ndef window_reverse(windows, window_size, H, W):\n    \"\"\"\n    Args:\n        windows: (num_windows*B, 
window_size, window_size, C)\n        window_size (int): Window size\n        H (int): Height of image\n        W (int): Width of image\n    Returns:\n        x: (B, H, W, C)\n    \"\"\"\n    B = int(windows.shape[0] / (H * W / window_size / window_size))\n    x = windows.view(B, H // window_size, W // window_size, window_size,\n                     window_size, -1)\n    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)\n    return x\n\n\nclass WindowAttention(nn.Module):\n    \"\"\"Window based multi-head self attention (W-MSA) module with relative\n    position bias.\n\n    It supports both of shifted and non-shifted window.\n    Args:\n        dim (int): Number of input channels.\n        window_size (tuple[int]): The height and width of the window.\n        num_heads (int): Number of attention heads.\n        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set\n        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0\n        proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0\n    \"\"\"\n\n    def __init__(self,\n                 dim,\n                 window_size,\n                 num_heads,\n                 qkv_bias=True,\n                 qk_scale=None,\n                 attn_drop=0.,\n                 proj_drop=0.):\n\n        super().__init__()\n        self.dim = dim\n        self.window_size = window_size  # Wh, Ww\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        self.scale = qk_scale or head_dim**-0.5\n\n        # define a parameter table of relative position bias\n        self.relative_position_bias_table = nn.Parameter(\n            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),\n                        num_heads))  # 2*Wh-1 * 2*Ww-1, nH\n\n        # get pair-wise relative position index for each token inside the window\n        coords_h = torch.arange(self.window_size[0])\n        coords_w = torch.arange(self.window_size[1])\n        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww\n        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww\n        relative_coords = coords_flatten[:, :,\n                                         None] - coords_flatten[:,\n                                                                None, :]  # 2, Wh*Ww, Wh*Ww\n        relative_coords = relative_coords.permute(\n            1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2\n        relative_coords[:, :,\n                        0] += self.window_size[0] - 1  # shift to start from 0\n        relative_coords[:, :, 1] += self.window_size[1] - 1\n        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1\n        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\n        self.register_buffer('relative_position_index',\n                             relative_position_index)\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        
self.proj_drop = nn.Dropout(proj_drop)\n\n        trunc_normal_(self.relative_position_bias_table, std=.02)\n        self.softmax = nn.Softmax(dim=-1)\n\n    def forward(self, x, mask=None):\n        \"\"\"Forward function.\n\n        Args:\n            x: input features with shape of (num_windows*B, N, C)\n            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None\n        \"\"\"\n        B_, N, C = x.shape\n        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads,\n                                  C // self.num_heads).permute(2, 0, 3, 1, 4)\n        q, k, v = qkv[0], qkv[1], qkv[\n            2]  # make torchscript happy (cannot use tensor as tuple)\n\n        q = q * self.scale\n        attn = (q @ k.transpose(-2, -1))\n\n        relative_position_bias = self.relative_position_bias_table[\n            self.relative_position_index.view(-1)].view(\n                self.window_size[0] * self.window_size[1],\n                self.window_size[0] * self.window_size[1],\n                -1)  # Wh*Ww,Wh*Ww,nH\n        relative_position_bias = relative_position_bias.permute(\n            2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww\n        attn = attn + relative_position_bias.unsqueeze(0)\n\n        if mask is not None:\n            nW = mask.shape[0]\n            attn = attn.view(B_ // nW, nW, self.num_heads, N,\n                             N) + mask.unsqueeze(1).unsqueeze(0)\n            attn = attn.view(-1, self.num_heads, N, N)\n            attn = self.softmax(attn)\n        else:\n            attn = self.softmax(attn)\n\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n\nclass SwinTransformerBlock(nn.Module):\n    \"\"\"Swin Transformer Block.\n\n    Args:\n        dim (int): Number of input channels.\n        num_heads (int): Number of attention heads.\n        window_size (int): Window size.\n        shift_size 
(int): Shift size for SW-MSA.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n        drop (float, optional): Dropout rate. Default: 0.0\n        attn_drop (float, optional): Attention dropout rate. Default: 0.0\n        drop_path (float, optional): Stochastic depth rate. Default: 0.0\n        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU\n        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm\n    \"\"\"\n\n    def __init__(self,\n                 dim,\n                 num_heads,\n                 window_size=7,\n                 shift_size=0,\n                 mlp_ratio=4.,\n                 qkv_bias=True,\n                 qk_scale=None,\n                 drop=0.,\n                 attn_drop=0.,\n                 drop_path=0.,\n                 act_layer=nn.GELU,\n                 norm_layer=nn.LayerNorm):\n        super().__init__()\n        self.dim = dim\n        self.num_heads = num_heads\n        self.window_size = window_size\n        self.shift_size = shift_size\n        self.mlp_ratio = mlp_ratio\n        assert 0 <= self.shift_size < self.window_size, 'shift_size must in 0-window_size'\n\n        self.norm1 = norm_layer(dim)\n        self.attn = WindowAttention(\n            dim,\n            window_size=to_2tuple(self.window_size),\n            num_heads=num_heads,\n            qkv_bias=qkv_bias,\n            qk_scale=qk_scale,\n            attn_drop=attn_drop,\n            proj_drop=drop)\n\n        self.drop_path = DropPath(\n            drop_path) if drop_path > 0. 
else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(\n            in_features=dim,\n            hidden_features=mlp_hidden_dim,\n            act_layer=act_layer,\n            drop=drop)\n\n        self.H = None\n        self.W = None\n\n    def forward(self, x, mask_matrix):\n        \"\"\"Forward function.\n\n        Args:\n            x: Input feature, tensor size (B, H*W, C).\n            H, W: Spatial resolution of the input feature.\n            mask_matrix: Attention mask for cyclic shift.\n        \"\"\"\n        B, L, C = x.shape\n        H, W = self.H, self.W\n        assert L == H * W, 'input feature has wrong size'\n\n        shortcut = x\n        x = self.norm1(x)\n        x = x.view(B, H, W, C)\n\n        # pad feature maps to multiples of window size\n        pad_l = pad_t = 0\n        pad_r = (self.window_size - W % self.window_size) % self.window_size\n        pad_b = (self.window_size - H % self.window_size) % self.window_size\n        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))\n        _, Hp, Wp, _ = x.shape\n\n        # cyclic shift\n        if self.shift_size > 0:\n            shifted_x = torch.roll(\n                x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))\n            attn_mask = mask_matrix\n        else:\n            shifted_x = x\n            attn_mask = None\n\n        # partition windows\n        x_windows = window_partition(\n            shifted_x, self.window_size)  # nW*B, window_size, window_size, C\n        x_windows = x_windows.view(-1, self.window_size * self.window_size,\n                                   C)  # nW*B, window_size*window_size, C\n\n        # W-MSA/SW-MSA\n        attn_windows = self.attn(\n            x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C\n\n        # merge windows\n        attn_windows = attn_windows.view(-1, self.window_size,\n                                         self.window_size, 
C)\n        shifted_x = window_reverse(attn_windows, self.window_size, Hp,\n                                   Wp)  # B H' W' C\n\n        # reverse cyclic shift\n        if self.shift_size > 0:\n            x = torch.roll(\n                shifted_x,\n                shifts=(self.shift_size, self.shift_size),\n                dims=(1, 2))\n        else:\n            x = shifted_x\n\n        if pad_r > 0 or pad_b > 0:\n            x = x[:, :H, :W, :].contiguous()\n\n        x = x.view(B, H * W, C)\n\n        # FFN\n        x = shortcut + self.drop_path(x)\n        x = x + self.drop_path(self.mlp(self.norm2(x)))\n\n        return x\n\n\nclass PatchMerging(nn.Module):\n    \"\"\" Patch Merging Layer\n    Args:\n        dim (int): Number of input channels.\n        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm\n    \"\"\"\n\n    def __init__(self, dim, norm_layer=nn.LayerNorm):\n        super().__init__()\n        self.dim = dim\n        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)\n        self.norm = norm_layer(4 * dim)\n\n    def forward(self, x, H, W):\n        \"\"\"Forward function.\n\n        Args:\n            x: Input feature, tensor size (B, H*W, C).\n            H, W: Spatial resolution of the input feature.\n        \"\"\"\n        B, L, C = x.shape\n        assert L == H * W, 'input feature has wrong size'\n\n        x = x.view(B, H, W, C)\n\n        # padding\n        pad_input = (H % 2 == 1) or (W % 2 == 1)\n        if pad_input:\n            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))\n\n        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C\n        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C\n        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C\n        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C\n        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C\n        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C\n\n        x = self.norm(x)\n        x = self.reduction(x)\n\n        return x\n\n\nclass BasicLayer(nn.Module):\n    
\"\"\"A basic Swin Transformer layer for one stage.\n\n    Args:\n        dim (int): Number of feature channels.\n        depth (int): Number of blocks in this stage.\n        num_heads (int): Number of attention heads.\n        window_size (int): Local window size. Default: 7.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n        drop (float, optional): Dropout rate. Default: 0.0\n        attn_drop (float, optional): Attention dropout rate. Default: 0.0\n        drop_path (float | list[float], optional): Stochastic depth rate. Default: 0.0\n        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm\n        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None\n        with_cp (bool): Whether to use checkpointing to save memory. 
Default: False.\n    \"\"\"\n\n    def __init__(self,\n                 dim,\n                 depth,\n                 num_heads,\n                 window_size=7,\n                 mlp_ratio=4.,\n                 qkv_bias=True,\n                 qk_scale=None,\n                 drop=0.,\n                 attn_drop=0.,\n                 drop_path=0.,\n                 norm_layer=nn.LayerNorm,\n                 downsample=None,\n                 with_cp=False):\n        super().__init__()\n        self.window_size = window_size\n        self.shift_size = window_size // 2\n        self.depth = depth\n        self.with_cp = with_cp\n\n        # build blocks\n        self.blocks = nn.ModuleList([\n            SwinTransformerBlock(\n                dim=dim,\n                num_heads=num_heads,\n                window_size=window_size,\n                shift_size=0 if (i % 2 == 0) else window_size // 2,\n                mlp_ratio=mlp_ratio,\n                qkv_bias=qkv_bias,\n                qk_scale=qk_scale,\n                drop=drop,\n                attn_drop=attn_drop,\n                drop_path=drop_path[i]\n                if isinstance(drop_path, list) else drop_path,\n                norm_layer=norm_layer) for i in range(depth)\n        ])\n\n        # patch merging layer\n        if downsample is not None:\n            self.downsample = downsample(dim=dim, norm_layer=norm_layer)\n        else:\n            self.downsample = None\n\n    def forward(self, x, H, W):\n        \"\"\"Forward function.\n\n        Args:\n            x: Input feature, tensor size (B, H*W, C).\n            H, W: Spatial resolution of the input feature.\n        \"\"\"\n\n        # calculate attention mask for SW-MSA\n        Hp = int(np.ceil(H / self.window_size)) * self.window_size\n        Wp = int(np.ceil(W / self.window_size)) * self.window_size\n        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1\n        h_slices = (slice(0, -self.window_size),\n         
           slice(-self.window_size,\n                          -self.shift_size), slice(-self.shift_size, None))\n        w_slices = (slice(0, -self.window_size),\n                    slice(-self.window_size,\n                          -self.shift_size), slice(-self.shift_size, None))\n        cnt = 0\n        for h in h_slices:\n            for w in w_slices:\n                img_mask[:, h, w, :] = cnt\n                cnt += 1\n\n        mask_windows = window_partition(\n            img_mask, self.window_size)  # nW, window_size, window_size, 1\n        mask_windows = mask_windows.view(-1,\n                                         self.window_size * self.window_size)\n        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)\n        attn_mask = attn_mask.masked_fill(attn_mask != 0,\n                                          float(-100.0)).masked_fill(\n                                              attn_mask == 0, float(0.0))\n        attn_mask = attn_mask.to(dtype=x.dtype)\n        for blk in self.blocks:\n            blk.H, blk.W = H, W\n            if self.with_cp:\n                x = checkpoint.checkpoint(blk, x, attn_mask)\n            else:\n                x = blk(x, attn_mask)\n        if self.downsample is not None:\n            x_down = self.downsample(x, H, W)\n            Wh, Ww = (H + 1) // 2, (W + 1) // 2\n            return x, H, W, x_down, Wh, Ww\n        else:\n            return x, H, W, x, H, W\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\" Image to Patch Embedding\n    Args:\n        patch_size (int): Patch token size. Default: 4.\n        in_chans (int): Number of input image channels. Default: 3.\n        embed_dims (int): Number of linear projection output channels. Default: 96.\n        norm_layer (nn.Module, optional): Normalization layer. 
Default: None\n    \"\"\"\n\n    def __init__(self,\n                 patch_size=4,\n                 in_chans=3,\n                 embed_dims=96,\n                 norm_layer=None):\n        super().__init__()\n        patch_size = to_2tuple(patch_size)\n        self.patch_size = patch_size\n\n        self.in_chans = in_chans\n        self.embed_dims = embed_dims\n\n        self.proj = nn.Conv2d(\n            in_chans, embed_dims, kernel_size=patch_size, stride=patch_size)\n        if norm_layer is not None:\n            self.norm = norm_layer(embed_dims)\n        else:\n            self.norm = None\n\n    def forward(self, x):\n        \"\"\"Forward function.\"\"\"\n        # padding\n        _, _, H, W = x.size()\n        if W % self.patch_size[1] != 0:\n            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))\n        if H % self.patch_size[0] != 0:\n            x = F.pad(x,\n                      (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))\n\n        x = self.proj(x)  # B C Wh Ww\n        if self.norm is not None:\n            Wh, Ww = x.size(2), x.size(3)\n            x = x.flatten(2).transpose(1, 2)\n            x = self.norm(x)\n            x = x.transpose(1, 2).view(-1, self.embed_dims, Wh, Ww)\n\n        return x\n\n\n# @BACKBONES_Seg.register_module()\n@BACKBONES.register_module()\nclass SwinTransformerDIY(nn.Module):\n    \"\"\" Swin Transformer backbone.\n        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -\n          https://arxiv.org/pdf/2103.14030\n    Args:\n        pretrain_img_size (int): Input image size for training the pretrained model,\n            used in absolute position embedding. Default: 224.\n        patch_size (int | tuple(int)): Patch size. Default: 4.\n        in_chans (int): Number of input image channels. Default: 3.\n        embed_dims (int): Number of linear projection output channels. 
Default: 96.\n        depths (tuple[int]): Depths of each Swin Transformer stage.\n        num_heads (tuple[int]): Number of attention heads in each stage.\n        window_size (int): Window size. Default: 7.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.\n        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.\n        drop_rate (float): Dropout rate. Default: 0.\n        attn_drop_rate (float): Attention dropout rate. Default: 0.\n        drop_path_rate (float): Stochastic depth rate. Default: 0.2.\n        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.\n        use_abs_pos_embed (bool): If True, add absolute position embedding to the patch embedding. Default: False.\n        patch_norm (bool): If True, add normalization after patch embedding. Default: True.\n        out_indices (Sequence[int]): Output from which stages.\n        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).\n            -1 means not freezing any parameters.\n        with_cp (bool): Whether to use checkpointing to save memory. 
Default: False.\n    \"\"\"\n\n    def __init__(self,\n                 pretrain_img_size=224,\n                 patch_size=4,\n                 in_chans=3,\n                 embed_dims=96,\n                 depths=[2, 2, 6, 2],\n                 num_heads=[3, 6, 12, 24],\n                 window_size=7,\n                 mlp_ratio=4.,\n                 qkv_bias=True,\n                 qk_scale=None,\n                 drop_rate=0.,\n                 attn_drop_rate=0.,\n                 drop_path_rate=0.2,\n                 norm_layer=nn.LayerNorm,\n                 use_abs_pos_embed=False,\n                 patch_norm=True,\n                 out_indices=(0, 1, 2, 3),\n                 frozen_stages=-1,\n                 with_cp=False,\n                 output_img=False,\n                 pretrained=None):\n        super().__init__()\n        self.output_img = output_img\n\n        self.pretrain_img_size = pretrain_img_size\n        self.num_layers = len(depths)\n        self.embed_dims = embed_dims\n        self.ape = use_abs_pos_embed\n        self.patch_norm = patch_norm\n        self.out_indices = out_indices\n        self.frozen_stages = frozen_stages\n        self.pretrained = pretrained\n\n        # split image into non-overlapping patches\n        self.patch_embed = PatchEmbed(\n            patch_size=patch_size,\n            in_chans=in_chans,\n            embed_dims=embed_dims,\n            norm_layer=norm_layer if self.patch_norm else None)\n\n        # absolute position embedding\n        if self.ape:\n            pretrain_img_size = to_2tuple(pretrain_img_size)\n            patch_size = to_2tuple(patch_size)\n            patches_resolution = [\n                pretrain_img_size[0] // patch_size[0],\n                pretrain_img_size[1] // patch_size[1]\n            ]\n\n            self.absolute_pos_embed = nn.Parameter(\n                torch.zeros(1, embed_dims, patches_resolution[0],\n                            patches_resolution[1]))\n            
trunc_normal_(self.absolute_pos_embed, std=.02)\n\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n        # stochastic depth\n        dpr = [\n            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))\n        ]  # stochastic depth decay rule\n\n        # build layers\n        self.layers = nn.ModuleList()\n        for i_layer in range(self.num_layers):\n            layer = BasicLayer(\n                dim=int(embed_dims * 2**i_layer),\n                depth=depths[i_layer],\n                num_heads=num_heads[i_layer],\n                window_size=window_size,\n                mlp_ratio=mlp_ratio,\n                qkv_bias=qkv_bias,\n                qk_scale=qk_scale,\n                drop=drop_rate,\n                attn_drop=attn_drop_rate,\n                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],\n                norm_layer=norm_layer,\n                downsample=PatchMerging if\n                (i_layer < self.num_layers - 1) else None,\n                with_cp=with_cp)\n            self.layers.append(layer)\n\n        num_features = [int(embed_dims * 2**i) for i in range(self.num_layers)]\n        self.num_features = num_features\n\n        # add a norm layer for each output\n        for i_layer in out_indices:\n            layer = norm_layer(num_features[i_layer])\n            layer_name = f'norm{i_layer}'\n            self.add_module(layer_name, layer)\n\n        self._freeze_stages()\n\n    def _freeze_stages(self):\n        if self.frozen_stages >= 0:\n            self.patch_embed.eval()\n            for param in self.patch_embed.parameters():\n                param.requires_grad = False\n\n        if self.frozen_stages >= 1 and self.ape:\n            self.absolute_pos_embed.requires_grad = False\n\n        if self.frozen_stages >= 2:\n            self.pos_drop.eval()\n            for i in range(0, self.frozen_stages - 1):\n                m = self.layers[i]\n                m.eval()\n                for 
param in m.parameters():\n                    param.requires_grad = False\n\n    def init_weights(self, pretrained=None):\n        \"\"\"Initialize the weights in backbone.\n\n        Args:\n            pretrained (str, optional): Path to pre-trained weights.\n                Defaults to None.\n        \"\"\"\n        if pretrained is None and self.pretrained is not None:\n            pretrained = self.pretrained\n\n        def _init_weights(m):\n            if isinstance(m, nn.Linear):\n                trunc_normal_(m.weight, std=.02)\n                if isinstance(m, nn.Linear) and m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.LayerNorm):\n                nn.init.constant_(m.bias, 0)\n                nn.init.constant_(m.weight, 1.0)\n\n        if isinstance(pretrained, str):\n            self.apply(_init_weights)\n            logger = get_root_logger()\n            load_checkpoint(self, pretrained, strict=False, logger=logger)\n        elif pretrained is None:\n            self.apply(_init_weights)\n        else:\n            raise TypeError('pretrained must be a str or None')\n\n    def forward(self, x):\n        \"\"\"Forward function.\"\"\"\n        x_idty = x\n        x = self.patch_embed(x)\n\n        Wh, Ww = x.size(2), x.size(3)\n        if self.ape:\n            # interpolate the position embedding to the corresponding size\n            absolute_pos_embed = F.interpolate(\n                self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')\n            x = (x + absolute_pos_embed).flatten(2).transpose(1,\n                                                              2)  # B Wh*Ww C\n        else:\n            x = x.flatten(2).transpose(1, 2)\n        x = self.pos_drop(x)\n\n        outs = []\n        for i in range(self.num_layers):\n            layer = self.layers[i]\n            x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)\n\n            if i in self.out_indices:\n                norm_layer = 
getattr(self, f'norm{i}')\n                x_out = norm_layer(x_out)\n\n                out = x_out.view(-1, H, W,\n                                 self.num_features[i]).permute(0, 3, 1,\n                                                               2).contiguous()\n                outs.append(out)\n\n        if self.output_img:\n            outs.insert(0, x_idty)\n        return tuple(outs)\n\n    def train(self, mode=True):\n        \"\"\"Convert the model into training mode while keeping the frozen stages frozen.\"\"\"\n        super().train(mode)\n        self._freeze_stages()\n"
  },
  {
    "path": "swin/swin_transformer_rfp.py",
    "content": "import warnings\nfrom collections import OrderedDict\nfrom copy import deepcopy\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.utils.checkpoint as cp\nfrom mmcv.cnn import build_norm_layer, constant_init, trunc_normal_init, build_conv_layer\nfrom mmcv.cnn.bricks.transformer import FFN, build_dropout\nfrom mmcv.runner import BaseModule, ModuleList, _load_checkpoint\nfrom mmcv.utils import to_2tuple\n\nfrom mmdet.utils import get_root_logger\nfrom mmdet.models.builder import BACKBONES\nfrom .ckpt_convert import swin_converter\nfrom .transformer import PatchEmbed, PatchMerging\n\n\nclass WindowMSA(BaseModule):\n    \"\"\"Window based multi-head self-attention (W-MSA) module with relative\n    position bias.\n    Args:\n        embed_dims (int): Number of input channels.\n        num_heads (int): Number of attention heads.\n        window_size (tuple[int]): The height and width of the window.\n        qkv_bias (bool, optional):  If True, add a learnable bias to q, k, v.\n            Default: True.\n        qk_scale (float | None, optional): Override default qk scale of\n            head_dim ** -0.5 if set. Default: None.\n        attn_drop_rate (float, optional): Dropout ratio of attention weight.\n            Default: 0.0\n        proj_drop_rate (float, optional): Dropout ratio of output. 
Default: 0.\n        init_cfg (dict | None, optional): The Config for initialization.\n            Default: None.\n    \"\"\"\n\n    def __init__(self,\n                 embed_dims,\n                 num_heads,\n                 window_size,\n                 qkv_bias=True,\n                 qk_scale=None,\n                 attn_drop_rate=0.,\n                 proj_drop_rate=0.,\n                 init_cfg=None):\n        super().__init__()\n        self.embed_dims = embed_dims\n        self.window_size = window_size  # Wh, Ww\n        self.num_heads = num_heads\n        head_embed_dims = embed_dims // num_heads\n        self.scale = qk_scale or head_embed_dims ** -0.5\n        self.init_cfg = init_cfg\n\n        # define a parameter table of relative position bias\n        self.relative_position_bias_table = nn.Parameter(\n            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),\n                        num_heads))  # 2*Wh-1 * 2*Ww-1, nH\n\n        # About 2x faster than original impl\n        Wh, Ww = self.window_size\n        rel_index_coords = self.double_step_seq(2 * Ww - 1, Wh, 1, Ww)\n        rel_position_index = rel_index_coords + rel_index_coords.T\n        rel_position_index = rel_position_index.flip(1).contiguous()\n        self.register_buffer('relative_position_index', rel_position_index)\n\n        self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop_rate)\n        self.proj = nn.Linear(embed_dims, embed_dims)\n        self.proj_drop = nn.Dropout(proj_drop_rate)\n\n        self.softmax = nn.Softmax(dim=-1)\n\n    def init_weights(self):\n        trunc_normal_init(self.relative_position_bias_table, std=0.02)\n\n    def forward(self, x, mask=None):\n        \"\"\"\n        Args:\n            x (tensor): input features with shape of (num_windows*B, N, C)\n            mask (tensor | None, Optional): mask with shape of (num_windows,\n                Wh*Ww, Wh*Ww), value should be 
between (-inf, 0].\n        \"\"\"\n        B, N, C = x.shape\n        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,\n                                  C // self.num_heads).permute(2, 0, 3, 1, 4)\n        # make torchscript happy (cannot use tensor as tuple)\n        q, k, v = qkv[0], qkv[1], qkv[2]\n\n        q = q * self.scale\n        attn = (q @ k.transpose(-2, -1))\n\n        relative_position_bias = self.relative_position_bias_table[\n            self.relative_position_index.view(-1)].view(\n            self.window_size[0] * self.window_size[1],\n            self.window_size[0] * self.window_size[1],\n            -1)  # Wh*Ww,Wh*Ww,nH\n        relative_position_bias = relative_position_bias.permute(\n            2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww\n        attn = attn + relative_position_bias.unsqueeze(0)\n\n        if mask is not None:\n            nW = mask.shape[0]\n            attn = attn.view(B // nW, nW, self.num_heads, N,\n                             N) + mask.unsqueeze(1).unsqueeze(0)\n            attn = attn.view(-1, self.num_heads, N, N)\n        attn = self.softmax(attn)\n\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n    @staticmethod\n    def double_step_seq(step1, len1, step2, len2):\n        seq1 = torch.arange(0, step1 * len1, step1)\n        seq2 = torch.arange(0, step2 * len2, step2)\n        return (seq1[:, None] + seq2[None, :]).reshape(1, -1)\n\n\nclass ShiftWindowMSA(BaseModule):\n    \"\"\"Shifted Window Multihead Self-Attention Module.\n    Args:\n        embed_dims (int): Number of input channels.\n        num_heads (int): Number of attention heads.\n        window_size (int): The height and width of the window.\n        shift_size (int, optional): The shift step of each window towards\n            right-bottom. If zero, act as regular window-msa. 
Defaults to 0.\n        qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.\n            Default: True\n        qk_scale (float | None, optional): Override default qk scale of\n            head_dim ** -0.5 if set. Defaults: None.\n        attn_drop_rate (float, optional): Dropout ratio of attention weight.\n            Defaults: 0.\n        proj_drop_rate (float, optional): Dropout ratio of output.\n            Defaults: 0.\n        dropout_layer (dict, optional): The dropout_layer used before output.\n            Defaults: dict(type='DropPath', drop_prob=0.).\n        init_cfg (dict, optional): The extra config for initialization.\n            Default: None.\n    \"\"\"\n\n    def __init__(self,\n                 embed_dims,\n                 num_heads,\n                 window_size,\n                 shift_size=0,\n                 qkv_bias=True,\n                 qk_scale=None,\n                 attn_drop_rate=0,\n                 proj_drop_rate=0,\n                 dropout_layer=dict(type='DropPath', drop_prob=0.),\n                 init_cfg=None):\n        super().__init__(init_cfg)\n\n        self.window_size = window_size\n        self.shift_size = shift_size\n        assert 0 <= self.shift_size < self.window_size\n\n        self.w_msa = WindowMSA(\n            embed_dims=embed_dims,\n            num_heads=num_heads,\n            window_size=to_2tuple(window_size),\n            qkv_bias=qkv_bias,\n            qk_scale=qk_scale,\n            attn_drop_rate=attn_drop_rate,\n            proj_drop_rate=proj_drop_rate,\n            init_cfg=None)\n\n        self.drop = build_dropout(dropout_layer)\n\n    def forward(self, query, hw_shape):\n        B, L, C = query.shape\n        H, W = hw_shape\n        assert L == H * W, 'input feature has wrong size'\n        query = query.view(B, H, W, C)\n\n        # pad feature maps to multiples of window size\n        pad_r = (self.window_size - W % self.window_size) % self.window_size\n        pad_b = 
(self.window_size - H % self.window_size) % self.window_size\n        query = F.pad(query, (0, 0, 0, pad_r, 0, pad_b))\n        H_pad, W_pad = query.shape[1], query.shape[2]\n\n        # cyclic shift\n        if self.shift_size > 0:\n            shifted_query = torch.roll(\n                query,\n                shifts=(-self.shift_size, -self.shift_size),\n                dims=(1, 2))\n\n            # calculate attention mask for SW-MSA\n            img_mask = torch.zeros((1, H_pad, W_pad, 1), device=query.device)\n            h_slices = (slice(0, -self.window_size),\n                        slice(-self.window_size,\n                              -self.shift_size), slice(-self.shift_size, None))\n            w_slices = (slice(0, -self.window_size),\n                        slice(-self.window_size,\n                              -self.shift_size), slice(-self.shift_size, None))\n            cnt = 0\n            for h in h_slices:\n                for w in w_slices:\n                    img_mask[:, h, w, :] = cnt\n                    cnt += 1\n\n            # nW, window_size, window_size, 1\n            mask_windows = self.window_partition(img_mask)\n            mask_windows = mask_windows.view(\n                -1, self.window_size * self.window_size)\n            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)\n            attn_mask = attn_mask.masked_fill(attn_mask != 0,\n                                              float(-100.0)).masked_fill(\n                attn_mask == 0, float(0.0))\n        else:\n            shifted_query = query\n            attn_mask = None\n\n        # nW*B, window_size, window_size, C\n        query_windows = self.window_partition(shifted_query)\n        # nW*B, window_size*window_size, C\n        query_windows = query_windows.view(-1, self.window_size ** 2, C)\n\n        # W-MSA/SW-MSA (nW*B, window_size*window_size, C)\n        attn_windows = self.w_msa(query_windows, mask=attn_mask)\n\n        # merge windows\n   
     attn_windows = attn_windows.view(-1, self.window_size,\n                                         self.window_size, C)\n\n        # B H' W' C\n        shifted_x = self.window_reverse(attn_windows, H_pad, W_pad)\n        # reverse cyclic shift\n        if self.shift_size > 0:\n            x = torch.roll(\n                shifted_x,\n                shifts=(self.shift_size, self.shift_size),\n                dims=(1, 2))\n        else:\n            x = shifted_x\n\n        if pad_r > 0 or pad_b > 0:\n            x = x[:, :H, :W, :].contiguous()\n\n        x = x.view(B, H * W, C)\n\n        x = self.drop(x)\n        return x\n\n    def window_reverse(self, windows, H, W):\n        \"\"\"\n        Args:\n            windows: (num_windows*B, window_size, window_size, C)\n            H (int): Height of image\n            W (int): Width of image\n        Returns:\n            x: (B, H, W, C)\n        \"\"\"\n        window_size = self.window_size\n        B = int(windows.shape[0] / (H * W / window_size / window_size))\n        x = windows.view(B, H // window_size, W // window_size, window_size,\n                         window_size, -1)\n        x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)\n        return x\n\n    def window_partition(self, x):\n        \"\"\"\n        Args:\n            x: (B, H, W, C)\n        Returns:\n            windows: (num_windows*B, window_size, window_size, C)\n        \"\"\"\n        B, H, W, C = x.shape\n        window_size = self.window_size\n        x = x.view(B, H // window_size, window_size, W // window_size,\n                   window_size, C)\n        windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()\n        windows = windows.view(-1, window_size, window_size, C)\n        return windows\n\n\nclass SwinBlock(BaseModule):\n    \"\"\"Implements one block in Swin Transformer.\n    Args:\n        embed_dims (int): The feature dimension.\n        num_heads (int): Parallel attention heads.\n        feedforward_channels (int): The hidden dimension for FFNs.\n      
  window_size (int, optional): The local window scale. Default: 7.\n        shift (bool, optional): whether to shift window or not. Default False.\n        qkv_bias (bool, optional): enable bias for qkv if True. Default: True.\n        qk_scale (float | None, optional): Override default qk scale of\n            head_dim ** -0.5 if set. Default: None.\n        drop_rate (float, optional): Dropout rate. Default: 0.\n        attn_drop_rate (float, optional): Attention dropout rate. Default: 0.\n        drop_path_rate (float, optional): Stochastic depth rate. Default: 0.\n        act_cfg (dict, optional): The config dict of activation function.\n            Default: dict(type='GELU').\n        norm_cfg (dict, optional): The config dict of normalization.\n            Default: dict(type='LN').\n        with_cp (bool, optional): Use checkpoint or not. Using checkpoint\n            will save some memory while slowing down the training speed.\n            Default: False.\n        init_cfg (dict | list | None, optional): The init config.\n            Default: None.\n    \"\"\"\n\n    def __init__(self,\n                 embed_dims,\n                 num_heads,\n                 feedforward_channels,\n                 window_size=7,\n                 shift=False,\n                 qkv_bias=True,\n                 qk_scale=None,\n                 drop_rate=0.,\n                 attn_drop_rate=0.,\n                 drop_path_rate=0.,\n                 act_cfg=dict(type='GELU'),\n                 norm_cfg=dict(type='LN'),\n                 with_cp=False,\n                 init_cfg=None):\n\n        super(SwinBlock, self).__init__()\n\n        self.init_cfg = init_cfg\n        self.with_cp = with_cp\n\n        self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]\n        self.attn = ShiftWindowMSA(\n            embed_dims=embed_dims,\n            num_heads=num_heads,\n            window_size=window_size,\n            shift_size=window_size // 2 if shift else 0,\n            
qkv_bias=qkv_bias,\n            qk_scale=qk_scale,\n            attn_drop_rate=attn_drop_rate,\n            proj_drop_rate=drop_rate,\n            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),\n            init_cfg=None)\n\n        self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]\n        self.ffn = FFN(\n            embed_dims=embed_dims,\n            feedforward_channels=feedforward_channels,\n            num_fcs=2,\n            ffn_drop=drop_rate,\n            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),\n            act_cfg=act_cfg,\n            add_identity=True,\n            init_cfg=None)\n\n    def forward(self, x, hw_shape):\n\n        def _inner_forward(x):\n            identity = x\n            x = self.norm1(x)\n            x = self.attn(x, hw_shape)\n\n            x = x + identity\n\n            identity = x\n            x = self.norm2(x)\n            x = self.ffn(x, identity=identity)\n\n            return x\n\n        if self.with_cp and x.requires_grad:\n            x = cp.checkpoint(_inner_forward, x)\n        else:\n            x = _inner_forward(x)\n\n        return x\n\n\nclass SwinBlockSequence(BaseModule):\n    \"\"\"Implements one stage in Swin Transformer.\n    Args:\n        embed_dims (int): The feature dimension.\n        num_heads (int): Parallel attention heads.\n        feedforward_channels (int): The hidden dimension for FFNs.\n        depth (int): The number of blocks in this stage.\n        window_size (int, optional): The local window scale. Default: 7.\n        qkv_bias (bool, optional): enable bias for qkv if True. Default: True.\n        qk_scale (float | None, optional): Override default qk scale of\n            head_dim ** -0.5 if set. Default: None.\n        drop_rate (float, optional): Dropout rate. Default: 0.\n        attn_drop_rate (float, optional): Attention dropout rate. Default: 0.\n        drop_path_rate (float | list[float], optional): Stochastic depth\n            rate. 
Default: 0.\n        downsample (BaseModule | None, optional): The downsample operation\n            module. Default: None.\n        act_cfg (dict, optional): The config dict of activation function.\n            Default: dict(type='GELU').\n        norm_cfg (dict, optional): The config dict of normalization.\n            Default: dict(type='LN').\n        with_cp (bool, optional): Use checkpoint or not. Using checkpoint\n            will save some memory while slowing down the training speed.\n            Default: False.\n        init_cfg (dict | list | None, optional): The init config.\n            Default: None.\n    \"\"\"\n\n    def __init__(self,\n                 embed_dims,\n                 num_heads,\n                 feedforward_channels,\n                 depth,\n                 window_size=7,\n                 qkv_bias=True,\n                 qk_scale=None,\n                 drop_rate=0.,\n                 attn_drop_rate=0.,\n                 drop_path_rate=0.,\n                 downsample=None,\n                 act_cfg=dict(type='GELU'),\n                 norm_cfg=dict(type='LN'),\n                 with_cp=False,\n                 init_cfg=None):\n        super().__init__(init_cfg=init_cfg)\n\n        if isinstance(drop_path_rate, list):\n            drop_path_rates = drop_path_rate\n            assert len(drop_path_rates) == depth\n        else:\n            drop_path_rates = [deepcopy(drop_path_rate) for _ in range(depth)]\n\n        self.blocks = ModuleList()\n        for i in range(depth):\n            block = SwinBlock(\n                embed_dims=embed_dims,\n                num_heads=num_heads,\n                feedforward_channels=feedforward_channels,\n                window_size=window_size,\n                shift=False if i % 2 == 0 else True,\n                qkv_bias=qkv_bias,\n                qk_scale=qk_scale,\n                drop_rate=drop_rate,\n                attn_drop_rate=attn_drop_rate,\n                
drop_path_rate=drop_path_rates[i],\n                act_cfg=act_cfg,\n                norm_cfg=norm_cfg,\n                with_cp=with_cp,\n                init_cfg=None)\n            self.blocks.append(block)\n\n        self.downsample = downsample\n\n    def forward(self, x, hw_shape):\n        for block in self.blocks:\n            x = block(x, hw_shape)\n\n        if self.downsample:\n            x_down, down_hw_shape = self.downsample(x, hw_shape)\n            return x_down, down_hw_shape, x, hw_shape\n        else:\n            return x, hw_shape, x, hw_shape\n\n\nclass SwinTransformer(BaseModule):\n    \"\"\" Swin Transformer\n    A PyTorch implement of : `Swin Transformer:\n    Hierarchical Vision Transformer using Shifted Windows`  -\n        https://arxiv.org/abs/2103.14030\n    Inspiration from\n    https://github.com/microsoft/Swin-Transformer\n    Args:\n        pretrain_img_size (int | tuple[int]): The size of input image when\n            pretrain. Defaults: 224.\n        in_channels (int): The num of input channels.\n            Defaults: 3.\n        embed_dims (int): The feature dimension. Default: 96.\n        patch_size (int | tuple[int]): Patch size. Default: 4.\n        window_size (int): Window size. Default: 7.\n        mlp_ratio (int): Ratio of mlp hidden dim to embedding dim.\n            Default: 4.\n        depths (tuple[int]): Depths of each Swin Transformer stage.\n            Default: (2, 2, 6, 2).\n        num_heads (tuple[int]): Parallel attention heads of each Swin\n            Transformer stage. Default: (3, 6, 12, 24).\n        strides (tuple[int]): The patch merging or patch embedding stride of\n            each Swin Transformer stage. (In swin, we set kernel size equal to\n            stride.) Default: (4, 2, 2, 2).\n        out_indices (tuple[int]): Output from which stages.\n            Default: (0, 1, 2, 3).\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key,\n            value. 
Default: True.\n        qk_scale (float | None, optional): Override default qk scale of\n            head_dim ** -0.5 if set. Default: None.\n        patch_norm (bool): Whether to add a norm layer for patch embed and\n            patch merging. Default: True.\n        drop_rate (float): Dropout rate. Default: 0.\n        attn_drop_rate (float): Attention dropout rate. Default: 0.\n        drop_path_rate (float): Stochastic depth rate. Default: 0.1.\n        use_abs_pos_embed (bool): If True, add absolute position embedding to\n            the patch embedding. Default: False.\n        act_cfg (dict): Config dict for activation layer.\n            Default: dict(type='GELU').\n        norm_cfg (dict): Config dict for normalization layer at\n            output of backbone. Default: dict(type='LN').\n        with_cp (bool, optional): Use checkpoint or not. Using checkpoint\n            will save some memory while slowing down the training speed.\n            Default: False.\n        pretrained (str, optional): Model pretrained path. Default: None.\n        convert_weights (bool): The flag indicates whether the\n            pre-trained model is from the original repo. 
We may need\n            to convert some keys to make it compatible.\n            Default: False.\n        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).\n            -1 means not freezing any parameters.\n        init_cfg (dict, optional): The Config for initialization.\n            Defaults to None.\n    \"\"\"\n\n    def __init__(self,\n                 pretrain_img_size=224,\n                 in_channels=3,\n                 embed_dims=96,\n                 patch_size=4,\n                 window_size=7,\n                 mlp_ratio=4,\n                 depths=(2, 2, 6, 2),\n                 num_heads=(3, 6, 12, 24),\n                 strides=(4, 2, 2, 2),\n                 out_indices=(0, 1, 2, 3),\n                 qkv_bias=True,\n                 qk_scale=None,\n                 patch_norm=True,\n                 drop_rate=0.,\n                 attn_drop_rate=0.,\n                 drop_path_rate=0.1,\n                 use_abs_pos_embed=False,\n                 act_cfg=dict(type='GELU'),\n                 norm_cfg=dict(type='LN'),\n                 with_cp=False,\n                 pretrained=None,\n                 convert_weights=False,\n                 frozen_stages=-1,\n                 init_cfg=None):\n        self.convert_weights = convert_weights\n        self.frozen_stages = frozen_stages\n        if isinstance(pretrain_img_size, int):\n            pretrain_img_size = to_2tuple(pretrain_img_size)\n        elif isinstance(pretrain_img_size, tuple):\n            if len(pretrain_img_size) == 1:\n                pretrain_img_size = to_2tuple(pretrain_img_size[0])\n            assert len(pretrain_img_size) == 2, \\\n                f'The size of image should have length 1 or 2, ' \\\n                f'but got {len(pretrain_img_size)}'\n\n        assert not (init_cfg and pretrained), \\\n            'init_cfg and pretrained cannot be specified at the same time'\n        if isinstance(pretrained, str):\n            
warnings.warn('DeprecationWarning: pretrained is deprecated, '\n                          'please use \"init_cfg\" instead')\n            # Update the local init_cfg: BaseModule.__init__ below resets\n            # self.init_cfg from its argument, so assigning self.init_cfg\n            # here would silently drop the `pretrained` checkpoint.\n            init_cfg = dict(type='Pretrained', checkpoint=pretrained)\n        elif pretrained is not None:\n            raise TypeError('pretrained must be a str or None')\n\n        super(SwinTransformer, self).__init__(init_cfg=init_cfg)\n\n        num_layers = len(depths)\n        self.out_indices = out_indices\n        self.use_abs_pos_embed = use_abs_pos_embed\n\n        assert strides[0] == patch_size, 'Use non-overlapping patch embed.'\n\n        self.patch_embed = PatchEmbed(\n            in_channels=in_channels,\n            embed_dims=embed_dims,\n            conv_type='Conv2d',\n            kernel_size=patch_size,\n            stride=strides[0],\n            norm_cfg=norm_cfg if patch_norm else None,\n            init_cfg=None)\n\n        if self.use_abs_pos_embed:\n            patch_row = pretrain_img_size[0] // patch_size\n            patch_col = pretrain_img_size[1] // patch_size\n            num_patches = patch_row * patch_col\n            self.absolute_pos_embed = nn.Parameter(\n                torch.zeros((1, num_patches, embed_dims)))\n\n        self.drop_after_pos = nn.Dropout(p=drop_rate)\n\n        # set stochastic depth decay rule\n        total_depth = sum(depths)\n        dpr = [\n            x.item() for x in torch.linspace(0, drop_path_rate, total_depth)\n        ]\n\n        self.stages = ModuleList()\n        in_channels = embed_dims\n        for i in range(num_layers):\n            if i < num_layers - 1:\n                downsample = PatchMerging(\n                    in_channels=in_channels,\n                    out_channels=2 * in_channels,\n                    stride=strides[i + 1],\n                    norm_cfg=norm_cfg if patch_norm else None,\n                    init_cfg=None)\n            else:\n                downsample = None\n\n            stage = 
SwinBlockSequence(\n                embed_dims=in_channels,\n                num_heads=num_heads[i],\n                feedforward_channels=mlp_ratio * in_channels,\n                depth=depths[i],\n                window_size=window_size,\n                qkv_bias=qkv_bias,\n                qk_scale=qk_scale,\n                drop_rate=drop_rate,\n                attn_drop_rate=attn_drop_rate,\n                drop_path_rate=dpr[sum(depths[:i]):sum(depths[:i + 1])],\n                downsample=downsample,\n                act_cfg=act_cfg,\n                norm_cfg=norm_cfg,\n                with_cp=with_cp,\n                init_cfg=None)\n            self.stages.append(stage)\n            if downsample:\n                in_channels = downsample.out_channels\n\n        self.num_features = [int(embed_dims * 2 ** i) for i in range(num_layers)]\n        # Add a norm layer for each output\n        for i in out_indices:\n            layer = build_norm_layer(norm_cfg, self.num_features[i])[1]\n            layer_name = f'norm{i}'\n            self.add_module(layer_name, layer)\n\n    def train(self, mode=True):\n        \"\"\"Convert the model into training mode while keeping layers frozen.\"\"\"\n        super(SwinTransformer, self).train(mode)\n        self._freeze_stages()\n\n    def _freeze_stages(self):\n        if self.frozen_stages >= 0:\n            self.patch_embed.eval()\n            for param in self.patch_embed.parameters():\n                param.requires_grad = False\n            if self.use_abs_pos_embed:\n                self.absolute_pos_embed.requires_grad = False\n            self.drop_after_pos.eval()\n\n        for i in range(1, self.frozen_stages + 1):\n\n            if (i - 1) in self.out_indices:\n                norm_layer = getattr(self, f'norm{i - 1}')\n                norm_layer.eval()\n                for param in norm_layer.parameters():\n                    param.requires_grad = False\n\n            m = self.stages[i - 1]\n            
m.eval()\n            for param in m.parameters():\n                param.requires_grad = False\n\n    def init_weights(self):\n        logger = get_root_logger()\n        if self.init_cfg is None:\n            logger.warning(f'No pre-trained weights for '\n                           f'{self.__class__.__name__}, '\n                           f'training starts from scratch')\n            if self.use_abs_pos_embed:\n                trunc_normal_init(self.absolute_pos_embed, std=0.02)\n            for m in self.modules():\n                if isinstance(m, nn.Linear):\n                    trunc_normal_init(m.weight, std=.02)\n                    if m.bias is not None:\n                        constant_init(m.bias, 0)\n                elif isinstance(m, nn.LayerNorm):\n                    constant_init(m.bias, 0)\n                    constant_init(m.weight, 1.0)\n        else:\n            assert 'checkpoint' in self.init_cfg, f'Only support ' \\\n                                                  f'specifying `Pretrained` in ' \\\n                                                  f'`init_cfg` in ' \\\n                                                  f'{self.__class__.__name__} '\n            # Index with [] so both plain dicts and ConfigDicts work.\n            ckpt = _load_checkpoint(\n                self.init_cfg['checkpoint'], logger=logger, map_location='cpu')\n            if 'state_dict' in ckpt:\n                _state_dict = ckpt['state_dict']\n            elif 'model' in ckpt:\n                _state_dict = ckpt['model']\n            else:\n                _state_dict = ckpt\n\n            state_dict = OrderedDict()\n            for k, v in _state_dict.items():\n                if k.startswith('backbone.'):\n                    state_dict[k[9:]] = v\n\n            if self.convert_weights:\n                # support loading weights from the original repo\n                state_dict = swin_converter(state_dict)\n\n            # strip prefix of state_dict\n            if list(state_dict.keys())[0].startswith('module.'):\n                state_dict = 
{k[7:]: v for k, v in state_dict.items()}\n\n            # reshape absolute position embedding\n            if state_dict.get('absolute_pos_embed') is not None:\n                absolute_pos_embed = state_dict['absolute_pos_embed']\n                N1, L, C1 = absolute_pos_embed.size()\n                N2, C2, H, W = self.absolute_pos_embed.size()\n                if N1 != N2 or C1 != C2 or L != H * W:\n                    logger.warning('Error in loading absolute_pos_embed, pass')\n                else:\n                    state_dict['absolute_pos_embed'] = absolute_pos_embed.view(\n                        N2, H, W, C2).permute(0, 3, 1, 2).contiguous()\n\n            # interpolate position bias table if needed\n            relative_position_bias_table_keys = [\n                k for k in state_dict.keys()\n                if 'relative_position_bias_table' in k\n            ]\n            for table_key in relative_position_bias_table_keys:\n                table_pretrained = state_dict[table_key]\n                table_current = self.state_dict()[table_key]\n                L1, nH1 = table_pretrained.size()\n                L2, nH2 = table_current.size()\n                if nH1 != nH2:\n                    logger.warning(f'Error in loading {table_key}, pass')\n                elif L1 != L2:\n                    S1 = int(L1 ** 0.5)\n                    S2 = int(L2 ** 0.5)\n                    table_pretrained_resized = F.interpolate(\n                        table_pretrained.permute(1, 0).reshape(1, nH1, S1, S1),\n                        size=(S2, S2),\n                        mode='bicubic')\n                    state_dict[table_key] = table_pretrained_resized.view(\n                        nH2, L2).permute(1, 0).contiguous()\n\n            # load state_dict\n            self.load_state_dict(state_dict, False)\n\n    def forward(self, x):\n        x, hw_shape = self.patch_embed(x)\n\n        if self.use_abs_pos_embed:\n            x = x + self.absolute_pos_embed\n 
       x = self.drop_after_pos(x)\n\n        outs = []\n        for i, stage in enumerate(self.stages):\n            x, hw_shape, out, out_hw_shape = stage(x, hw_shape)\n            if i in self.out_indices:\n                norm_layer = getattr(self, f'norm{i}')\n                out = norm_layer(out)\n                out = out.view(-1, *out_hw_shape,\n                               self.num_features[i]).permute(0, 3, 1,\n                                                             2).contiguous()\n                outs.append(out)\n\n        return outs\n\n\nclass SwinRFPLayer(BaseModule):\n    \"\"\"Implements one stage in Swin Transformer.\n    Args:\n        embed_dims (int): The feature dimension.\n        num_heads (int): Parallel attention heads.\n        feedforward_channels (int): The hidden dimension for FFNs.\n        depth (int): The number of blocks in this stage.\n        window_size (int, optional): The local window scale. Default: 7.\n        qkv_bias (bool, optional): enable bias for qkv if True. Default: True.\n        qk_scale (float | None, optional): Override default qk scale of\n            head_dim ** -0.5 if set. Default: None.\n        drop_rate (float, optional): Dropout rate. Default: 0.\n        attn_drop_rate (float, optional): Attention dropout rate. Default: 0.\n        drop_path_rate (float | list[float], optional): Stochastic depth\n            rate. Default: 0.\n        downsample (BaseModule | None, optional): The downsample operation\n            module. Default: None.\n        act_cfg (dict, optional): The config dict of activation function.\n            Default: dict(type='GELU').\n        norm_cfg (dict, optional): The config dict of normalization.\n            Default: dict(type='LN').\n        with_cp (bool, optional): Use checkpoint or not. 
Using checkpoint\n            will save some memory while slowing down the training speed.\n            Default: False.\n        init_cfg (dict | list | None, optional): The init config.\n            Default: None.\n    \"\"\"\n\n    def __init__(self,\n                 embed_dims,\n                 num_heads,\n                 feedforward_channels,\n                 depth,\n                 window_size=7,\n                 qkv_bias=True,\n                 qk_scale=None,\n                 drop_rate=0.,\n                 attn_drop_rate=0.,\n                 drop_path_rate=0.,\n                 downsample=None,\n                 act_cfg=dict(type='GELU'),\n                 norm_cfg=dict(type='LN'),\n                 with_cp=False,\n                 # Added\n                 rfp_inplanes=None,\n                 # Added Done\n                 init_cfg=None):\n        super().__init__(init_cfg=init_cfg)\n\n        if isinstance(drop_path_rate, list):\n            drop_path_rates = drop_path_rate\n            assert len(drop_path_rates) == depth\n        else:\n            drop_path_rates = [deepcopy(drop_path_rate) for _ in range(depth)]\n\n        self.blocks = ModuleList()\n        for i in range(depth):\n            block = SwinBlock(\n                embed_dims=embed_dims,\n                num_heads=num_heads,\n                feedforward_channels=feedforward_channels,\n                window_size=window_size,\n                shift=False if i % 2 == 0 else True,\n                qkv_bias=qkv_bias,\n                qk_scale=qk_scale,\n                drop_rate=drop_rate,\n                attn_drop_rate=attn_drop_rate,\n                drop_path_rate=drop_path_rates[i],\n                act_cfg=act_cfg,\n                norm_cfg=norm_cfg,\n                with_cp=with_cp,\n                init_cfg=None)\n            self.blocks.append(block)\n\n        self.downsample = downsample\n\n        self.rfp_inplanes = rfp_inplanes\n        if self.rfp_inplanes:\n            
self.rfp_conv = build_conv_layer(\n                None,\n                self.rfp_inplanes,\n                embed_dims,\n                1,\n                stride=1,\n                bias=True)\n\n    def forward(self, x, hw_shape):\n        for block in self.blocks:\n            x = block(x, hw_shape)\n\n        if self.downsample:\n            x_down, down_hw_shape = self.downsample(x, hw_shape)\n            return x_down, down_hw_shape, x, hw_shape\n        else:\n            return x, hw_shape, x, hw_shape\n\n    def rfp_forward(self, x, hw_shape, rfp_feat):\n        for block in self.blocks:\n            x = block(x, hw_shape)\n\n        haw = hw_shape[0] * hw_shape[1]\n        if self.rfp_inplanes:\n            rfp_feat = self.rfp_conv(rfp_feat)\n            x = x + rfp_feat.permute((0, 2, 3, 1)) \\\n                .view(x.shape[0], haw, x.shape[2]).contiguous()\n\n        if self.downsample:\n            x_down, down_hw_shape = self.downsample(x, hw_shape)\n            return x_down, down_hw_shape, x, hw_shape\n        else:\n            return x, hw_shape, x, hw_shape\n\n\n@BACKBONES.register_module()\nclass SwinTransformerRFP(SwinTransformer):\n    def __init__(\n            self,\n            rfp_inplanes=None,\n            output_img=False,\n            # Old settings\n            pretrain_img_size=224,\n            in_channels=3,\n            embed_dims=96,\n            patch_size=4,\n            window_size=7,\n            mlp_ratio=4,\n            depths=(2, 2, 6, 2),\n            num_heads=(3, 6, 12, 24),\n            strides=(4, 2, 2, 2),\n            out_indices=(0, 1, 2, 3),\n            qkv_bias=True,\n            qk_scale=None,\n            patch_norm=True,\n            drop_rate=0.,\n            attn_drop_rate=0.,\n            drop_path_rate=0.1,\n            use_abs_pos_embed=False,\n            act_cfg=dict(type='GELU'),\n            norm_cfg=dict(type='LN'),\n            with_cp=False,\n            pretrained=None,\n            
convert_weights=False,\n            frozen_stages=-1,\n            init_cfg=None):\n        self.rfp_inplanes = rfp_inplanes\n        self.output_img = output_img\n        super().__init__(\n            pretrain_img_size=pretrain_img_size,\n            in_channels=in_channels,\n            embed_dims=embed_dims,\n            patch_size=patch_size,\n            window_size=window_size,\n            mlp_ratio=mlp_ratio,\n            depths=depths,\n            num_heads=num_heads,\n            strides=strides,\n            out_indices=out_indices,\n            qkv_bias=qkv_bias,\n            qk_scale=qk_scale,\n            patch_norm=patch_norm,\n            drop_rate=drop_rate,\n            attn_drop_rate=attn_drop_rate,\n            drop_path_rate=drop_path_rate,\n            use_abs_pos_embed=use_abs_pos_embed,\n            act_cfg=act_cfg,\n            norm_cfg=norm_cfg,\n            with_cp=with_cp,\n            pretrained=pretrained,\n            convert_weights=convert_weights,\n            frozen_stages=frozen_stages,\n            init_cfg=init_cfg\n        )\n        # Re-write Swin Block\n        self.stages = ModuleList()\n        in_channels = embed_dims\n        num_layers = len(depths)\n        total_depth = sum(depths)\n        dpr = [\n            x.item() for x in torch.linspace(0, drop_path_rate, total_depth)\n        ]\n        for i in range(num_layers):\n            if i < num_layers - 1:\n                downsample = PatchMerging(\n                    in_channels=in_channels,\n                    out_channels=2 * in_channels,\n                    stride=strides[i + 1],\n                    norm_cfg=norm_cfg if patch_norm else None,\n                    init_cfg=None)\n            else:\n                downsample = None\n\n            stage = SwinRFPLayer(\n                embed_dims=in_channels,\n                num_heads=num_heads[i],\n                feedforward_channels=mlp_ratio * in_channels,\n                depth=depths[i],\n             
   window_size=window_size,\n                qkv_bias=qkv_bias,\n                qk_scale=qk_scale,\n                drop_rate=drop_rate,\n                attn_drop_rate=attn_drop_rate,\n                drop_path_rate=dpr[sum(depths[:i]):sum(depths[:i + 1])],\n                downsample=downsample,\n                act_cfg=act_cfg,\n                norm_cfg=norm_cfg,\n                with_cp=with_cp,\n                rfp_inplanes=rfp_inplanes if i > 0 else None,\n                init_cfg=None)\n            self.stages.append(stage)\n            if downsample:\n                in_channels = downsample.out_channels\n\n    def forward(self, x):\n        \"\"\"Forward function.\"\"\"\n        outs = list(super().forward(x))\n        if self.output_img:\n            outs.insert(0, x)\n        return tuple(outs)\n\n    def rfp_forward(self, x, rfp_feats):\n        x, hw_shape = self.patch_embed(x)\n\n        if self.use_abs_pos_embed:\n            x = x + self.absolute_pos_embed\n        x = self.drop_after_pos(x)\n\n        outs = []\n        for i, stage in enumerate(self.stages):\n            rfp_feat = rfp_feats[i] if i > 0 else None\n            x, hw_shape, out, out_hw_shape = stage.rfp_forward(x, hw_shape, rfp_feat)\n            if i in self.out_indices:\n                norm_layer = getattr(self, f'norm{i}')\n                out = norm_layer(out)\n                out = out.view(-1, *out_hw_shape,\n                               self.num_features[i]).permute(0, 3, 1,\n                                                             2).contiguous()\n                outs.append(out)\n\n        return outs\n"
  },
  {
    "path": "swin/transformer.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport math\nimport warnings\nfrom typing import Sequence\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom mmcv.cnn import (build_activation_layer, build_conv_layer,\n                      build_norm_layer, xavier_init)\nfrom mmcv.cnn.bricks.registry import (TRANSFORMER_LAYER,\n                                      TRANSFORMER_LAYER_SEQUENCE)\nfrom mmcv.cnn.bricks.transformer import (BaseTransformerLayer,\n                                         TransformerLayerSequence,\n                                         build_transformer_layer_sequence)\nfrom mmcv.runner.base_module import BaseModule\nfrom mmcv.utils import to_2tuple\nfrom torch.nn.init import normal_\n\nfrom mmdet.models.utils.builder import TRANSFORMER\n\ntry:\n    from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention\n\nexcept ImportError:\n    warnings.warn(\n        '`MultiScaleDeformableAttention` in MMCV has been moved to '\n        '`mmcv.ops.multi_scale_deform_attn`, please update your MMCV')\n    from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention\n\n\ndef nlc_to_nchw(x, hw_shape):\n    \"\"\"Convert [N, L, C] shape tensor to [N, C, H, W] shape tensor.\n    Args:\n        x (Tensor): The input tensor of shape [N, L, C] before conversion.\n        hw_shape (Sequence[int]): The height and width of output feature map.\n    Returns:\n        Tensor: The output tensor of shape [N, C, H, W] after conversion.\n    \"\"\"\n    H, W = hw_shape\n    assert len(x.shape) == 3\n    B, L, C = x.shape\n    assert L == H * W, 'The seq_len does not match H, W'\n    return x.transpose(1, 2).reshape(B, C, H, W).contiguous()\n\n\ndef nchw_to_nlc(x):\n    \"\"\"Flatten [N, C, H, W] shape tensor to [N, L, C] shape tensor.\n    Args:\n        x (Tensor): The input tensor of shape [N, C, H, W] before conversion.\n    Returns:\n        Tensor: The output tensor of shape [N, L, C] after 
conversion.\n    \"\"\"\n    assert len(x.shape) == 4\n    return x.flatten(2).transpose(1, 2).contiguous()\n\n\nclass AdaptivePadding(nn.Module):\n    \"\"\"Applies padding to input (if needed) so that the input is fully\n    covered by the specified filter. It supports two modes, \"same\" and\n    \"corner\". The \"same\" mode matches the \"SAME\" padding mode in\n    TensorFlow and pads zeros around the input. The \"corner\" mode pads\n    zeros to the bottom right.\n    Args:\n        kernel_size (int | tuple): Size of the kernel. Default: 1.\n        stride (int | tuple): Stride of the filter. Default: 1.\n        dilation (int | tuple): Spacing between kernel elements.\n            Default: 1.\n        padding (str): Support \"same\" and \"corner\". The \"corner\" mode\n            pads zeros to the bottom right, and the \"same\" mode pads\n            zeros around the input. Default: \"corner\".\n    Example:\n        >>> kernel_size = 16\n        >>> stride = 16\n        >>> dilation = 1\n        >>> input = torch.rand(1, 1, 15, 17)\n        >>> adap_pad = AdaptivePadding(\n        ...     kernel_size=kernel_size,\n        ...     stride=stride,\n        ...     dilation=dilation,\n        ...     padding=\"corner\")\n        >>> out = adap_pad(input)\n        >>> assert (out.shape[2], out.shape[3]) == (16, 32)\n        >>> input = torch.rand(1, 1, 16, 17)\n        >>> out = adap_pad(input)\n        >>> assert (out.shape[2], out.shape[3]) == (16, 32)\n    \"\"\"\n\n    def __init__(self, kernel_size=1, stride=1, dilation=1, padding='corner'):\n\n        super(AdaptivePadding, self).__init__()\n\n        assert padding in ('same', 'corner')\n\n        kernel_size = to_2tuple(kernel_size)\n        stride = to_2tuple(stride)\n        padding = to_2tuple(padding)\n        dilation = to_2tuple(dilation)\n\n        self.padding = padding\n        self.kernel_size = kernel_size\n        self.stride = stride\n        self.dilation = dilation\n\n    def get_pad_shape(self, input_shape):\n        
input_h, input_w = input_shape\n        kernel_h, kernel_w = self.kernel_size\n        stride_h, stride_w = self.stride\n        output_h = math.ceil(input_h / stride_h)\n        output_w = math.ceil(input_w / stride_w)\n        pad_h = max((output_h - 1) * stride_h +\n                    (kernel_h - 1) * self.dilation[0] + 1 - input_h, 0)\n        pad_w = max((output_w - 1) * stride_w +\n                    (kernel_w - 1) * self.dilation[1] + 1 - input_w, 0)\n        return pad_h, pad_w\n\n    def forward(self, x):\n        pad_h, pad_w = self.get_pad_shape(x.size()[-2:])\n        if pad_h > 0 or pad_w > 0:\n            if self.padding == 'corner':\n                x = F.pad(x, [0, pad_w, 0, pad_h])\n            elif self.padding == 'same':\n                x = F.pad(x, [\n                    pad_w // 2, pad_w - pad_w // 2, pad_h // 2,\n                    pad_h - pad_h // 2\n                ])\n        return x\n\n\nclass PatchEmbed(BaseModule):\n    \"\"\"Image to Patch Embedding.\n    We use a conv layer to implement PatchEmbed.\n    Args:\n        in_channels (int): The number of input channels. Default: 3.\n        embed_dims (int): The dimensions of embedding. Default: 768.\n        conv_type (str): The type of the embedding conv layer.\n            Default: \"Conv2d\".\n        kernel_size (int): The kernel_size of embedding conv. Default: 16.\n        stride (int): The stride of embedding conv. If None, it is set to\n            `kernel_size`. Default: 16.\n        padding (int | tuple | str): The padding length of\n            embedding conv. When it is a string, it means the mode\n            of adaptive padding, supporting \"same\" and \"corner\".\n            Default: \"corner\".\n        dilation (int): The dilation rate of embedding conv. Default: 1.\n        bias (bool): Bias of embed conv. 
Default: True.\n        norm_cfg (dict, optional): Config dict for normalization layer.\n            Default: None.\n        input_size (int | tuple | None): The size of input, which will be\n            used to calculate the output size stored in `init_out_size`.\n            Default: None.\n        init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization.\n            Default: None.\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels=3,\n        embed_dims=768,\n        conv_type='Conv2d',\n        kernel_size=16,\n        stride=16,\n        padding='corner',\n        dilation=1,\n        bias=True,\n        norm_cfg=None,\n        input_size=None,\n        init_cfg=None,\n    ):\n        super(PatchEmbed, self).__init__(init_cfg=init_cfg)\n\n        self.embed_dims = embed_dims\n        if stride is None:\n            stride = kernel_size\n\n        kernel_size = to_2tuple(kernel_size)\n        stride = to_2tuple(stride)\n        dilation = to_2tuple(dilation)\n\n        if isinstance(padding, str):\n            self.adap_padding = AdaptivePadding(\n                kernel_size=kernel_size,\n                stride=stride,\n                dilation=dilation,\n                padding=padding)\n            # disable the padding of conv\n            padding = 0\n        else:\n            self.adap_padding = None\n        padding = to_2tuple(padding)\n\n        self.projection = build_conv_layer(\n            dict(type=conv_type),\n            in_channels=in_channels,\n            out_channels=embed_dims,\n            kernel_size=kernel_size,\n            stride=stride,\n            padding=padding,\n            dilation=dilation,\n            bias=bias)\n\n        if norm_cfg is not None:\n            self.norm = build_norm_layer(norm_cfg, embed_dims)[1]\n        else:\n            self.norm = None\n\n        if input_size:\n            input_size = to_2tuple(input_size)\n            # `init_out_size` would be used outside 
to\n            # calculate the num_patches\n            # when `use_abs_pos_embed` outside\n            self.init_input_size = input_size\n            if self.adap_padding:\n                pad_h, pad_w = self.adap_padding.get_pad_shape(input_size)\n                input_h, input_w = input_size\n                input_h = input_h + pad_h\n                input_w = input_w + pad_w\n                input_size = (input_h, input_w)\n\n            # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html\n            h_out = (input_size[0] + 2 * padding[0] - dilation[0] *\n                     (kernel_size[0] - 1) - 1) // stride[0] + 1\n            w_out = (input_size[1] + 2 * padding[1] - dilation[1] *\n                     (kernel_size[1] - 1) - 1) // stride[1] + 1\n            self.init_out_size = (h_out, w_out)\n        else:\n            self.init_input_size = None\n            self.init_out_size = None\n\n    def forward(self, x):\n        \"\"\"\n        Args:\n            x (Tensor): Has shape (B, C, H, W). In most cases, C is 3.\n        Returns:\n            tuple: Contains the embedded patches and their spatial shape.\n                - x (Tensor): Has shape (B, out_h * out_w, embed_dims)\n                - out_size (tuple[int]): Spatial shape of x, arranged as\n                    (out_h, out_w).\n        \"\"\"\n\n        if self.adap_padding:\n            x = self.adap_padding(x)\n\n        x = self.projection(x)\n        out_size = (x.shape[2], x.shape[3])\n        x = x.flatten(2).transpose(1, 2)\n        if self.norm is not None:\n            x = self.norm(x)\n        return x, out_size\n\n\nclass PatchMerging(BaseModule):\n    \"\"\"Merge patch feature map.\n    This layer groups feature map by kernel_size, and applies norm and linear\n    layers to the grouped feature map. 
Our implementation uses `nn.Unfold` to\n    merge patches, which is about 25% faster than the original implementation.\n    However, pretrained models need to be modified for compatibility.\n    Args:\n        in_channels (int): The num of input channels.\n        out_channels (int): The num of output channels.\n        kernel_size (int | tuple, optional): The kernel size in the unfold\n            layer. Default: 2.\n        stride (int | tuple, optional): The stride of the sliding blocks in the\n            unfold layer. If None, it is set to `kernel_size`. Default: None.\n        padding (int | tuple | str): The padding length of the\n            unfold layer. When it is a string, it means the mode\n            of adaptive padding, supporting \"same\" and \"corner\".\n            Default: \"corner\".\n        dilation (int | tuple, optional): Dilation parameter in the unfold\n            layer. 
Default: 1.\n        bias (bool, optional): Whether to add bias in linear layer or not.\n            Defaults: False.\n        norm_cfg (dict, optional): Config dict for normalization layer.\n            Default: dict(type='LN').\n        init_cfg (dict, optional): The extra config for initialization.\n            Default: None.\n    \"\"\"\n\n    def __init__(self,\n                 in_channels,\n                 out_channels,\n                 kernel_size=2,\n                 stride=None,\n                 padding='corner',\n                 dilation=1,\n                 bias=False,\n                 norm_cfg=dict(type='LN'),\n                 init_cfg=None):\n        super().__init__(init_cfg=init_cfg)\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        if stride:\n            stride = stride\n        else:\n            stride = kernel_size\n\n        kernel_size = to_2tuple(kernel_size)\n        stride = to_2tuple(stride)\n        dilation = to_2tuple(dilation)\n\n        if isinstance(padding, str):\n            self.adap_padding = AdaptivePadding(\n                kernel_size=kernel_size,\n                stride=stride,\n                dilation=dilation,\n                padding=padding)\n            # disable the padding of unfold\n            padding = 0\n        else:\n            self.adap_padding = None\n\n        padding = to_2tuple(padding)\n        self.sampler = nn.Unfold(\n            kernel_size=kernel_size,\n            dilation=dilation,\n            padding=padding,\n            stride=stride)\n\n        sample_dim = kernel_size[0] * kernel_size[1] * in_channels\n\n        if norm_cfg is not None:\n            self.norm = build_norm_layer(norm_cfg, sample_dim)[1]\n        else:\n            self.norm = None\n\n        self.reduction = nn.Linear(sample_dim, out_channels, bias=bias)\n\n    def forward(self, x, input_size):\n        \"\"\"\n        Args:\n            x (Tensor): Has shape (B, H*W, C_in).\n    
        input_size (tuple[int]): The spatial shape of x, arranged as (H, W).\n        Returns:\n            tuple: Contains merged results and its spatial shape.\n                - x (Tensor): Has shape (B, Merged_H * Merged_W, C_out)\n                - out_size (tuple[int]): Spatial shape of x, arranged as\n                    (Merged_H, Merged_W).\n        \"\"\"\n        B, L, C = x.shape\n        assert isinstance(input_size, Sequence), f'Expected ' \\\n                                                 f'input_size to be a ' \\\n                                                 f'`Sequence`, ' \\\n                                                 f'but got {input_size}'\n\n        H, W = input_size\n        assert L == H * W, 'input feature has wrong size'\n\n        x = x.view(B, H, W, C).permute([0, 3, 1, 2])  # B, C, H, W\n        # Use nn.Unfold to merge patches. About 25% faster than the original\n        # method, but pretrained models must be modified for compatibility.\n\n        if self.adap_padding:\n            x = self.adap_padding(x)\n            H, W = x.shape[-2:]\n\n        x = self.sampler(x)\n        # if kernel_size=2 and stride=2, x should have shape (B, 4*C, H/2*W/2)\n\n        out_h = (H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] *\n                 (self.sampler.kernel_size[0] - 1) -\n                 1) // self.sampler.stride[0] + 1\n        out_w = (W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] *\n                 (self.sampler.kernel_size[1] - 1) -\n                 1) // self.sampler.stride[1] + 1\n\n        output_size = (out_h, out_w)\n        x = x.transpose(1, 2)  # B, H/2*W/2, 4*C\n        x = self.norm(x) if self.norm else x\n        x = self.reduction(x)\n        return x, output_size\n\n\ndef inverse_sigmoid(x, eps=1e-5):\n    \"\"\"Inverse function of sigmoid.\n    Args:\n        x (Tensor): The tensor to invert.\n        eps (float): EPS used to clamp the input to avoid numerical\n            
overflow. Defaults to 1e-5.\n    Returns:\n        Tensor: The inverse sigmoid of x, with the same\n            shape as the input.\n    \"\"\"\n    x = x.clamp(min=0, max=1)\n    x1 = x.clamp(min=eps)\n    x2 = (1 - x).clamp(min=eps)\n    return torch.log(x1 / x2)\n\n"
  },
  {
    "path": "tools/dataset/cityscapes_instance_idmap.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport argparse\nimport os.path as osp\n\nimport mmcv\nfrom cityscapesscripts.preparation.json2instanceImg import json2instanceImg\n\n\ndef convert_json_to_label(json_file):\n    label_file = json_file.replace('_polygons.json', '_instanceTrainIds.png')\n    json2instanceImg(json_file, label_file, 'trainIds')\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(\n        description='Convert Cityscapes annotations to instance TrainId maps')\n    parser.add_argument('cityscapes_path', help='cityscapes data path')\n    parser.add_argument('--gt-dir', default='gtFine', type=str)\n    parser.add_argument('-o', '--out-dir', help='output path')\n    parser.add_argument(\n        '--nproc', default=1, type=int, help='number of processes')\n    args = parser.parse_args()\n    return args\n\n\ndef main():\n    args = parse_args()\n    cityscapes_path = args.cityscapes_path\n    out_dir = args.out_dir if args.out_dir else cityscapes_path\n    mmcv.mkdir_or_exist(out_dir)\n\n    gt_dir = osp.join(cityscapes_path, args.gt_dir)\n\n    poly_files = []\n    for poly in mmcv.scandir(gt_dir, '_polygons.json', recursive=True):\n        poly_file = osp.join(gt_dir, poly)\n        poly_files.append(poly_file)\n    if args.nproc > 1:\n        mmcv.track_parallel_progress(convert_json_to_label, poly_files,\n                                     args.nproc)\n    else:\n        mmcv.track_progress(convert_json_to_label, poly_files)\n\n\n# Requires mmcv and cityscapesscripts. Usage:\n# python tools/dataset/cityscapes_instance_idmap.py {PATH/TO/CITYSCAPES} --nproc 56\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "tools/dataset/youtubevis2coco.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport argparse\nimport copy\nimport os\nimport os.path as osp\nfrom collections import defaultdict\n\nimport mmcv\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(\n        description='YouTube-VIS to COCO Video format')\n    parser.add_argument(\n        '-i',\n        '--input',\n        help='root directory of YouTube-VIS annotations',\n    )\n    parser.add_argument(\n        '-o',\n        '--output',\n        help='directory to save coco formatted label file',\n    )\n    parser.add_argument(\n        '--version',\n        choices=['2019', '2021'],\n        help='The version of YouTube-VIS Dataset',\n    )\n    return parser.parse_args()\n\n\ndef convert_vis(ann_dir, save_dir, dataset_version, mode='train'):\n    \"\"\"Convert YouTube-VIS dataset in COCO style.\n    Args:\n        ann_dir (str): The path of YouTube-VIS dataset.\n        save_dir (str): The path to save `VIS`.\n        dataset_version (str): The version of dataset. Options are '2019',\n            '2021'.\n        mode (str): Convert train dataset or validation dataset or test\n            dataset. Options are 'train', 'valid', 'test'. 
Default: 'train'.\n    \"\"\"\n    assert dataset_version in ['2019', '2021']\n    assert mode in ['train', 'valid', 'test']\n    VIS = defaultdict(list)\n    records = dict(vid_id=1, img_id=1, ann_id=1, global_instance_id=1)\n    obj_num_classes = dict()\n\n    if dataset_version == '2019':\n        official_anns = mmcv.load(osp.join(ann_dir, f'{mode}.json'))\n    elif dataset_version == '2021':\n        official_anns = mmcv.load(osp.join(ann_dir, mode, 'instances.json'))\n    VIS['categories'] = copy.deepcopy(official_anns['categories'])\n\n    has_annotations = mode == 'train'\n    if has_annotations:\n        vid_to_anns = defaultdict(list)\n        for ann_info in official_anns['annotations']:\n            vid_to_anns[ann_info['video_id']].append(ann_info)\n\n    video_infos = official_anns['videos']\n    for video_info in video_infos:\n        video_name = video_info['file_names'][0].split('/')[0]\n        video = dict(id=video_info['id'], name=video_name)\n        VIS['videos'].append(video)\n\n        num_frames = len(video_info['file_names'])\n        width = video_info['width']\n        height = video_info['height']\n        if has_annotations:\n            ann_infos_in_video = vid_to_anns[video_info['id']]\n            instance_id_maps = dict()\n\n        for frame_id in range(num_frames):\n            image = dict(\n                file_name=video_info['file_names'][frame_id],\n                height=height,\n                width=width,\n                id=records['img_id'],\n                frame_id=frame_id,\n                video_id=video_info['id'])\n            VIS['images'].append(image)\n\n            if has_annotations:\n                for ann_info in ann_infos_in_video:\n                    bbox = ann_info['bboxes'][frame_id]\n                    if bbox is None:\n                        continue\n\n                    category_id = ann_info['category_id']\n                    track_id = ann_info['id']\n                    segmentation = 
ann_info['segmentations'][frame_id]\n                    area = ann_info['areas'][frame_id]\n                    assert isinstance(category_id, int)\n                    assert isinstance(track_id, int)\n                    assert segmentation is not None\n                    assert area is not None\n\n                    if track_id in instance_id_maps:\n                        instance_id = instance_id_maps[track_id]\n                    else:\n                        instance_id = records['global_instance_id']\n                        records['global_instance_id'] += 1\n                        instance_id_maps[track_id] = instance_id\n\n                    ann = dict(\n                        id=records['ann_id'],\n                        video_id=video_info['id'],\n                        image_id=records['img_id'],\n                        category_id=category_id,\n                        instance_id=instance_id,\n                        bbox=bbox,\n                        segmentation=segmentation,\n                        area=area,\n                        iscrowd=ann_info['iscrowd'])\n\n                    if category_id not in obj_num_classes:\n                        obj_num_classes[category_id] = 1\n                    else:\n                        obj_num_classes[category_id] += 1\n\n                    VIS['annotations'].append(ann)\n                    records['ann_id'] += 1\n            records['img_id'] += 1\n        records['vid_id'] += 1\n\n    if not osp.isdir(save_dir):\n        os.makedirs(save_dir)\n    mmcv.dump(VIS,\n              osp.join(save_dir, f'youtube_vis_{dataset_version}_{mode}.json'))\n    print(f'-----YouTube VIS {dataset_version} {mode}------')\n    print(f'{records[\"vid_id\"]- 1} videos')\n    print(f'{records[\"img_id\"]- 1} images')\n    if has_annotations:\n        print(f'{records[\"ann_id\"] - 1} objects')\n        print(f'{records[\"global_instance_id\"] - 1} instances')\n    print('-----------------------')\n    if 
has_annotations:\n        for i in range(1, len(VIS['categories']) + 1):\n            class_name = VIS['categories'][i - 1]['name']\n            # classes without any annotated object are reported as 0\n            n_objs = obj_num_classes.get(i, 0)\n            print(f'Class {i} {class_name} has {n_objs} objects.')\n\n\ndef main():\n    args = parse_args()\n    for sub_set in ['train', 'valid', 'test']:\n        convert_vis(args.input, args.output, args.version, sub_set)\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "tools/dist_step_test.sh",
    "content": "#!/usr/bin/env bash\n\nCONFIG=$1\nCHECKPOINT=$2\nGPUS=$3\nPORT=${PORT:-29500}\n\nPYTHONPATH=\"$(dirname $0)/..\":$PYTHONPATH \\\npython -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \\\n    $(dirname \"$0\")/test_step.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4}"
  },
  {
    "path": "tools/dist_test.sh",
    "content": "#!/usr/bin/env bash\n\nCONFIG=$1\nCHECKPOINT=$2\nGPUS=$3\nPORT=${PORT:-29500}\n\nPYTHONPATH=\"$(dirname $0)/..\":$PYTHONPATH \\\npython -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \\\n    $(dirname \"$0\")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4}\n"
  },
  {
    "path": "tools/dist_train.sh",
    "content": "#!/usr/bin/env bash\n\nCONFIG=$1\nGPUS=$2\nPORT=${PORT:-$((29500 + $RANDOM % 29))}\n\nPYTHONPATH=\"$(dirname $0)/..\":$PYTHONPATH \\\npython -m torch.distributed.launch  --nproc_per_node=$GPUS --master_port=$PORT \\\n    $(dirname \"$0\")/train.py $CONFIG --launcher pytorch ${@:3}\n"
  },
  {
    "path": "tools/dist_train_new.sh",
    "content": "#!/usr/bin/env bash\n\nCONFIG=$1\nGPUS=$2\nPORT=${PORT:-$((29500 + $RANDOM % 29))}\n\nPYTHONPATH=\"$(dirname $0)/..\":$PYTHONPATH \\\npython -m torch.distributed.run  --nproc_per_node=$GPUS --master_port=$PORT \\\n    $(dirname \"$0\")/train_new.py $CONFIG --launcher pytorch ${@:3}\n"
  },
  {
    "path": "tools/dist_vps_test.sh",
    "content": "#!/usr/bin/env bash\n\nCONFIG=$1\nCHECKPOINT=$2\nGPUS=$3\nPORT=${PORT:-29500}\n\nPYTHONPATH=\"$(dirname $0)/..\":$PYTHONPATH \\\npython -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \\\n    $(dirname \"$0\")/test_vps.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4}"
  },
  {
    "path": "tools/docker.sh",
    "content": "#!/bin/bash\n\nDATALOC=${DATALOC:-~/datasets}\nLOGLOC=${LOGLOC:-~/logger}\nIMG=${IMG:-\"harbory/openmmlab:latest\"}\n\ndocker run --gpus all -it --rm --ipc=host --net=host -v $(pwd):/data -v $DATALOC:/data/data -v $LOGLOC:/data/logger $IMG\n"
  },
  {
    "path": "tools/eval_dstq.py",
    "content": "import argparse\nimport os\n\nimport mmcv\nimport numpy as np\nimport torch\nfrom mmcv import ProgressBar\n\nimport torch.nn.functional as F\n\nfrom tools.utils.DSTQ import DSTQuality\nfrom tools.utils.STQ import STQuality\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description='Evaluation of DSTQ')\n    parser.add_argument('result_path')\n    parser.add_argument('--gt-path', default='data/kitti-dvps')\n    parser.add_argument('--split', default='val')\n    parser.add_argument(\n        '--depth',\n        action='store_true',\n        help='eval depth')\n    parser.add_argument('--nproc', default=1, type=int, help='number of processes')\n    args = parser.parse_args()\n    return args\n\n\ndef updater(pred_ins_name,\n            pred_cls_name,\n            pred_dep_name,\n            gt_cls_seq_name,\n            gt_ins_seq_name,\n            gt_dep_seq_name,\n            updater_obj,\n            seq_id):\n    pred_ins = mmcv.imread(pred_ins_name, flag='unchanged').astype(np.int32)\n    pred_cls = mmcv.imread(pred_cls_name, flag='unchanged').astype(np.int32)\n    pred_dep = mmcv.imread(pred_dep_name, flag='unchanged').astype(np.float32) if pred_dep_name is not None else None\n\n    gt_ins = mmcv.imread(gt_ins_seq_name, flag='unchanged').astype(np.int32)\n    gt_cls = mmcv.imread(gt_cls_seq_name, flag='unchanged').astype(np.int32)\n    gt_dep = mmcv.imread(gt_dep_seq_name, flag='unchanged').astype(np.float32) if gt_dep_seq_name is not None else None\n    if pred_dep is not None:\n        pred_dep = F.interpolate(torch.from_numpy(pred_dep)[None][None], size=gt_dep.shape)[0][0].numpy()\n\n    valid_mask_seg = gt_cls != 255\n\n    pred_masked_ps = pred_cls[valid_mask_seg] * (2 ** 16) + pred_ins[valid_mask_seg]\n    gt_masked_ps = gt_cls[valid_mask_seg] * (2 ** 16) + gt_ins[valid_mask_seg]\n\n    if pred_dep_name is not None:\n        valid_mask_dep = gt_dep > 0.\n\n        pred_masked_depth = pred_dep[valid_mask_dep]\n        
gt_masked_depth = gt_dep[valid_mask_dep]\n\n        updater_obj.update_state(gt_masked_ps, pred_masked_ps, gt_masked_depth, pred_masked_depth, seq_id)\n    else:\n        updater_obj.update_state(gt_masked_ps, pred_masked_ps, seq_id)\n\n\ndef eval_dstq(result_dir, gt_dir, seq_ids, with_depth=True):\n    if with_depth:\n        dstq_obj = DSTQuality(\n            num_classes=19,\n            things_list=list(range(8)),\n            ignore_label=255,\n            label_bit_shift=16,\n            offset=2 ** 16 * 256,\n            depth_threshold=(1.25,),\n        )\n    else:\n        dstq_obj = STQuality(\n            num_classes=19,\n            things_list=list(range(8)),\n            ignore_label=255,\n            label_bit_shift=16,\n            offset=2 ** 16 * 256,\n        )\n\n    gt_names = list(mmcv.scandir(gt_dir))\n    gt_cls_names = sorted(list(filter(lambda x: 'gtFine_class' in x, gt_names)))\n    gt_ins_names = sorted(list(filter(lambda x: 'gtFine_instance' in x, gt_names)))\n    if with_depth:\n        gt_dep_names = sorted(list(filter(lambda x: 'depth' in x, gt_names)))\n    else:\n        gt_dep_names = None\n\n    for seq_id in seq_ids:\n        pred_name_panoptic = list(mmcv.scandir(os.path.join(result_dir, 'panoptic', str(seq_id))))\n        pred_ins_names = sorted(list(filter(lambda x: 'ins' in x, pred_name_panoptic)))\n        pred_cls_names = sorted(list(filter(lambda x: 'cat' in x, pred_name_panoptic)))\n        if with_depth:\n            pred_name_depth = list(mmcv.scandir(os.path.join(result_dir, 'depth', str(seq_id))))\n            pred_dep_names = sorted(pred_name_depth)\n        else:\n            pred_dep_names = [None] * len(pred_ins_names)\n        gt_cls_seq_names = list(filter(lambda x: x.startswith('{:06d}'.format(seq_id)), gt_cls_names))\n        gt_ins_seq_names = list(filter(lambda x: x.startswith('{:06d}'.format(seq_id)), gt_ins_names))\n        if with_depth:\n            gt_dep_seq_names = list(filter(lambda x: 
x.startswith('{:06d}'.format(seq_id)), gt_dep_names))\n        else:\n            gt_dep_seq_names = [None] * len(gt_cls_seq_names)\n        prog_bar = ProgressBar(len(pred_ins_names))\n        for pred_ins_name, pred_cls_name, pred_dep_name, gt_cls_seq_name, gt_ins_seq_name, gt_dep_seq_name in zip(\n                pred_ins_names, pred_cls_names, pred_dep_names, gt_cls_seq_names, gt_ins_seq_names, gt_dep_seq_names\n        ):\n            prog_bar.update()\n            updater(\n                os.path.join(result_dir, 'panoptic', str(seq_id), pred_ins_name),\n                os.path.join(result_dir, 'panoptic', str(seq_id), pred_cls_name),\n                os.path.join(result_dir, 'depth', str(seq_id), pred_dep_name) if pred_dep_name is not None else None,\n                os.path.join(gt_dir, gt_cls_seq_name),\n                os.path.join(gt_dir, gt_ins_seq_name),\n                os.path.join(gt_dir, gt_dep_seq_name) if gt_dep_seq_name is not None else None,\n                dstq_obj,\n                seq_id\n            )\n    result = dstq_obj.result()\n    print(result)\n\n\nif __name__ == '__main__':\n    args = parse_args()\n    result_path = args.result_path\n    gt_path = args.gt_path\n    split = args.split\n    eval_dstq(result_path, os.path.join(gt_path, 'video_sequence', split), [8], args.depth)\n"
  },
  {
    "path": "tools/eval_dstq_step.py",
    "content": "import argparse\nimport os\n\nimport mmcv\nimport numpy as np\nimport torch\nfrom mmcv import ProgressBar\n\nimport torch.nn.functional as F\n\nfrom tools.utils.DSTQ import DSTQuality\nfrom tools.utils.STQ import STQuality\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description='Evaluation of DSTQ')\n    parser.add_argument('result_path')\n    parser.add_argument('--gt-path', default='data/kitti-step')\n    parser.add_argument('--split', default='val')\n    parser.add_argument(\n        '--depth',\n        action='store_true',\n        help='eval depth')\n    parser.add_argument('--nproc', default=1, type=int, help='number of processes')\n    args = parser.parse_args()\n    return args\n\n\ndef updater(pred_ins_name,\n            pred_cls_name,\n            pred_dep_name,\n            gt_pan_seq_name,\n            gt_dep_seq_name,\n            updater_obj,\n            seq_id):\n    pred_ins = mmcv.imread(pred_ins_name, flag='unchanged').astype(np.int32)\n    pred_cls = mmcv.imread(pred_cls_name, flag='unchanged').astype(np.int32)\n    pred_dep = mmcv.imread(pred_dep_name, flag='unchanged').astype(np.float32) if pred_dep_name is not None else None\n\n    gt_pan = mmcv.imread(gt_pan_seq_name, flag='color', channel_order='rgb')\n    gt_cls = gt_pan[..., 0].astype(np.int32)\n    gt_ins = gt_pan[..., 1].astype(np.int32) * 256 + gt_pan[..., 2].astype(np.int32)\n    gt_dep = mmcv.imread(gt_dep_seq_name, flag='unchanged').astype(np.float32) if gt_dep_seq_name is not None else None\n    if pred_dep is not None:\n        pred_dep = F.interpolate(torch.from_numpy(pred_dep)[None][None], size=gt_dep.shape)[0][0].numpy()\n\n    valid_mask_seg = gt_cls != 255\n\n    pred_masked_ps = pred_cls[valid_mask_seg] * (2 ** 16) + pred_ins[valid_mask_seg]\n    gt_masked_ps = gt_cls[valid_mask_seg] * (2 ** 16) + gt_ins[valid_mask_seg]\n\n    if pred_dep_name is not None:\n        valid_mask_dep = gt_dep > 0.\n\n        pred_masked_depth = 
pred_dep[valid_mask_dep]\n        gt_masked_depth = gt_dep[valid_mask_dep]\n\n        updater_obj.update_state(gt_masked_ps, pred_masked_ps, gt_masked_depth, pred_masked_depth, seq_id)\n    else:\n        updater_obj.update_state(gt_masked_ps, pred_masked_ps, seq_id)\n\n\ndef eval_dstq(result_dir, gt_dir, seq_ids, with_depth=True):\n    if with_depth:\n        dstq_obj = DSTQuality(\n            num_classes=19,\n            things_list=list(range(11, 19)),\n            ignore_label=255,\n            label_bit_shift=16,\n            offset=2 ** 16 * 256,\n            depth_threshold=(1.25,),\n        )\n    else:\n        dstq_obj = STQuality(\n            num_classes=19,\n            things_list=list(range(11, 19)),\n            ignore_label=255,\n            label_bit_shift=16,\n            offset=2 ** 16 * 256,\n        )\n\n    gt_names = list(mmcv.scandir(gt_dir))\n    gt_pan_names = sorted(list(filter(lambda x: 'panoptic' in x, gt_names)))\n    if with_depth:\n        gt_dep_names = sorted(list(filter(lambda x: 'depth' in x, gt_names)))\n    else:\n        gt_dep_names = None\n\n    for seq_id in seq_ids:\n        pred_name_panoptic = list(mmcv.scandir(os.path.join(result_dir, 'panoptic', str(seq_id))))\n        pred_ins_names = sorted(list(filter(lambda x: 'ins' in x, pred_name_panoptic)))\n        pred_cls_names = sorted(list(filter(lambda x: 'cat' in x, pred_name_panoptic)))\n        if with_depth:\n            pred_name_depth = list(mmcv.scandir(os.path.join(result_dir, 'depth', str(seq_id))))\n            pred_dep_names = sorted(pred_name_depth)\n        else:\n            pred_dep_names = [None] * len(pred_ins_names)\n        gt_pan_seq_names = list(filter(lambda x: x.startswith('{:06d}'.format(seq_id)), gt_pan_names))\n        if with_depth:\n            gt_dep_seq_names = list(filter(lambda x: x.startswith('{:06d}'.format(seq_id)), gt_dep_names))\n        else:\n            gt_dep_seq_names = [None] * len(gt_pan_seq_names)\n        prog_bar = 
ProgressBar(len(pred_ins_names))\n        for pred_ins_name, pred_cls_name, pred_dep_name, gt_pan_seq_name, gt_dep_seq_name in zip(\n                pred_ins_names, pred_cls_names, pred_dep_names, gt_pan_seq_names, gt_dep_seq_names\n        ):\n            prog_bar.update()\n            updater(\n                os.path.join(result_dir, 'panoptic', str(seq_id), pred_ins_name),\n                os.path.join(result_dir, 'panoptic', str(seq_id), pred_cls_name),\n                os.path.join(result_dir, 'depth', str(seq_id), pred_dep_name) if pred_dep_name is not None else None,\n                os.path.join(gt_dir, gt_pan_seq_name),\n                os.path.join(gt_dir, gt_dep_seq_name) if gt_dep_seq_name is not None else None,\n                dstq_obj,\n                seq_id\n            )\n    result = dstq_obj.result()\n    print(result)\n\n\nif __name__ == '__main__':\n    args = parse_args()\n    result_path = args.result_path\n    gt_path = args.gt_path\n    split = args.split\n    eval_dstq(result_path, os.path.join(gt_path, 'video_sequence', split), [2, 6, 7, 8, 10, 13, 14, 16, 18], args.depth)\n"
  },
  {
    "path": "tools/eval_dstq_vipseg.py",
    "content": "import argparse\nimport os\n\nimport mmcv\nimport numpy as np\nimport torch\nfrom mmcv import ProgressBar\n\nimport torch.nn.functional as F\n\nfrom tools.utils.DSTQ import DSTQuality\nfrom tools.utils.STQ import STQuality\n\nCLASSES = [\n    {\"id\": 0, \"name\": \"wall\", \"isthing\": 0, \"color\": [120, 120, 120]},\n    {\"id\": 1, \"name\": \"ceiling\", \"isthing\": 0, \"color\": [180, 120, 120]},\n    {\"id\": 2, \"name\": \"door\", \"isthing\": 1, \"color\": [6, 230, 230]},\n    {\"id\": 3, \"name\": \"stair\", \"isthing\": 0, \"color\": [80, 50, 50]},\n    {\"id\": 4, \"name\": \"ladder\", \"isthing\": 1, \"color\": [4, 200, 3]},\n    {\"id\": 5, \"name\": \"escalator\", \"isthing\": 0, \"color\": [120, 120, 80]},\n    {\"id\": 6, \"name\": \"Playground_slide\", \"isthing\": 0, \"color\": [140, 140, 140]},\n    {\"id\": 7, \"name\": \"handrail_or_fence\", \"isthing\": 0, \"color\": [204, 5, 255]},\n    {\"id\": 8, \"name\": \"window\", \"isthing\": 1, \"color\": [230, 230, 230]},\n    {\"id\": 9, \"name\": \"rail\", \"isthing\": 0, \"color\": [4, 250, 7]},\n    {\"id\": 10, \"name\": \"goal\", \"isthing\": 1, \"color\": [224, 5, 255]},\n    {\"id\": 11, \"name\": \"pillar\", \"isthing\": 0, \"color\": [235, 255, 7]},\n    {\"id\": 12, \"name\": \"pole\", \"isthing\": 0, \"color\": [150, 5, 61]},\n    {\"id\": 13, \"name\": \"floor\", \"isthing\": 0, \"color\": [120, 120, 70]},\n    {\"id\": 14, \"name\": \"ground\", \"isthing\": 0, \"color\": [8, 255, 51]},\n    {\"id\": 15, \"name\": \"grass\", \"isthing\": 0, \"color\": [255, 6, 82]},\n    {\"id\": 16, \"name\": \"sand\", \"isthing\": 0, \"color\": [143, 255, 140]},\n    {\"id\": 17, \"name\": \"athletic_field\", \"isthing\": 0, \"color\": [204, 255, 4]},\n    {\"id\": 18, \"name\": \"road\", \"isthing\": 0, \"color\": [255, 51, 7]},\n    {\"id\": 19, \"name\": \"path\", \"isthing\": 0, \"color\": [204, 70, 3]},\n    {\"id\": 20, \"name\": \"crosswalk\", \"isthing\": 0, \"color\": [0, 102, 
200]},\n    {\"id\": 21, \"name\": \"building\", \"isthing\": 0, \"color\": [61, 230, 250]},\n    {\"id\": 22, \"name\": \"house\", \"isthing\": 0, \"color\": [255, 6, 51]},\n    {\"id\": 23, \"name\": \"bridge\", \"isthing\": 0, \"color\": [11, 102, 255]},\n    {\"id\": 24, \"name\": \"tower\", \"isthing\": 0, \"color\": [255, 7, 71]},\n    {\"id\": 25, \"name\": \"windmill\", \"isthing\": 0, \"color\": [255, 9, 224]},\n    {\"id\": 26, \"name\": \"well_or_well_lid\", \"isthing\": 0, \"color\": [9, 7, 230]},\n    {\"id\": 27, \"name\": \"other_construction\", \"isthing\": 0, \"color\": [220, 220, 220]},\n    {\"id\": 28, \"name\": \"sky\", \"isthing\": 0, \"color\": [255, 9, 92]},\n    {\"id\": 29, \"name\": \"mountain\", \"isthing\": 0, \"color\": [112, 9, 255]},\n    {\"id\": 30, \"name\": \"stone\", \"isthing\": 0, \"color\": [8, 255, 214]},\n    {\"id\": 31, \"name\": \"wood\", \"isthing\": 0, \"color\": [7, 255, 224]},\n    {\"id\": 32, \"name\": \"ice\", \"isthing\": 0, \"color\": [255, 184, 6]},\n    {\"id\": 33, \"name\": \"snowfield\", \"isthing\": 0, \"color\": [10, 255, 71]},\n    {\"id\": 34, \"name\": \"grandstand\", \"isthing\": 0, \"color\": [255, 41, 10]},\n    {\"id\": 35, \"name\": \"sea\", \"isthing\": 0, \"color\": [7, 255, 255]},\n    {\"id\": 36, \"name\": \"river\", \"isthing\": 0, \"color\": [224, 255, 8]},\n    {\"id\": 37, \"name\": \"lake\", \"isthing\": 0, \"color\": [102, 8, 255]},\n    {\"id\": 38, \"name\": \"waterfall\", \"isthing\": 0, \"color\": [255, 61, 6]},\n    {\"id\": 39, \"name\": \"water\", \"isthing\": 0, \"color\": [255, 194, 7]},\n    {\"id\": 40, \"name\": \"billboard_or_Bulletin_Board\", \"isthing\": 0, \"color\": [255, 122, 8]},\n    {\"id\": 41, \"name\": \"sculpture\", \"isthing\": 1, \"color\": [0, 255, 20]},\n    {\"id\": 42, \"name\": \"pipeline\", \"isthing\": 0, \"color\": [255, 8, 41]},\n    {\"id\": 43, \"name\": \"flag\", \"isthing\": 1, \"color\": [255, 5, 153]},\n    {\"id\": 44, \"name\": 
\"parasol_or_umbrella\", \"isthing\": 1, \"color\": [6, 51, 255]},\n    {\"id\": 45, \"name\": \"cushion_or_carpet\", \"isthing\": 0, \"color\": [235, 12, 255]},\n    {\"id\": 46, \"name\": \"tent\", \"isthing\": 1, \"color\": [160, 150, 20]},\n    {\"id\": 47, \"name\": \"roadblock\", \"isthing\": 1, \"color\": [0, 163, 255]},\n    {\"id\": 48, \"name\": \"car\", \"isthing\": 1, \"color\": [140, 140, 140]},\n    {\"id\": 49, \"name\": \"bus\", \"isthing\": 1, \"color\": [250, 10, 15]},\n    {\"id\": 50, \"name\": \"truck\", \"isthing\": 1, \"color\": [20, 255, 0]},\n    {\"id\": 51, \"name\": \"bicycle\", \"isthing\": 1, \"color\": [31, 255, 0]},\n    {\"id\": 52, \"name\": \"motorcycle\", \"isthing\": 1, \"color\": [255, 31, 0]},\n    {\"id\": 53, \"name\": \"wheeled_machine\", \"isthing\": 0, \"color\": [255, 224, 0]},\n    {\"id\": 54, \"name\": \"ship_or_boat\", \"isthing\": 1, \"color\": [153, 255, 0]},\n    {\"id\": 55, \"name\": \"raft\", \"isthing\": 1, \"color\": [0, 0, 255]},\n    {\"id\": 56, \"name\": \"airplane\", \"isthing\": 1, \"color\": [255, 71, 0]},\n    {\"id\": 57, \"name\": \"tyre\", \"isthing\": 0, \"color\": [0, 235, 255]},\n    {\"id\": 58, \"name\": \"traffic_light\", \"isthing\": 0, \"color\": [0, 173, 255]},\n    {\"id\": 59, \"name\": \"lamp\", \"isthing\": 0, \"color\": [31, 0, 255]},\n    {\"id\": 60, \"name\": \"person\", \"isthing\": 1, \"color\": [11, 200, 200]},\n    {\"id\": 61, \"name\": \"cat\", \"isthing\": 1, \"color\": [255, 82, 0]},\n    {\"id\": 62, \"name\": \"dog\", \"isthing\": 1, \"color\": [0, 255, 245]},\n    {\"id\": 63, \"name\": \"horse\", \"isthing\": 1, \"color\": [0, 61, 255]},\n    {\"id\": 64, \"name\": \"cattle\", \"isthing\": 1, \"color\": [0, 255, 112]},\n    {\"id\": 65, \"name\": \"other_animal\", \"isthing\": 1, \"color\": [0, 255, 133]},\n    {\"id\": 66, \"name\": \"tree\", \"isthing\": 0, \"color\": [255, 0, 0]},\n    {\"id\": 67, \"name\": \"flower\", \"isthing\": 0, \"color\": [255, 163, 0]},\n    
{\"id\": 68, \"name\": \"other_plant\", \"isthing\": 0, \"color\": [255, 102, 0]},\n    {\"id\": 69, \"name\": \"toy\", \"isthing\": 0, \"color\": [194, 255, 0]},\n    {\"id\": 70, \"name\": \"ball_net\", \"isthing\": 0, \"color\": [0, 143, 255]},\n    {\"id\": 71, \"name\": \"backboard\", \"isthing\": 0, \"color\": [51, 255, 0]},\n    {\"id\": 72, \"name\": \"skateboard\", \"isthing\": 1, \"color\": [0, 82, 255]},\n    {\"id\": 73, \"name\": \"bat\", \"isthing\": 0, \"color\": [0, 255, 41]},\n    {\"id\": 74, \"name\": \"ball\", \"isthing\": 1, \"color\": [0, 255, 173]},\n    {\"id\": 75, \"name\": \"cupboard_or_showcase_or_storage_rack\", \"isthing\": 0, \"color\": [10, 0, 255]},\n    {\"id\": 76, \"name\": \"box\", \"isthing\": 1, \"color\": [173, 255, 0]},\n    {\"id\": 77, \"name\": \"traveling_case_or_trolley_case\", \"isthing\": 1, \"color\": [0, 255, 153]},\n    {\"id\": 78, \"name\": \"basket\", \"isthing\": 1, \"color\": [255, 92, 0]},\n    {\"id\": 79, \"name\": \"bag_or_package\", \"isthing\": 1, \"color\": [255, 0, 255]},\n    {\"id\": 80, \"name\": \"trash_can\", \"isthing\": 0, \"color\": [255, 0, 245]},\n    {\"id\": 81, \"name\": \"cage\", \"isthing\": 0, \"color\": [255, 0, 102]},\n    {\"id\": 82, \"name\": \"plate\", \"isthing\": 1, \"color\": [255, 173, 0]},\n    {\"id\": 83, \"name\": \"tub_or_bowl_or_pot\", \"isthing\": 1, \"color\": [255, 0, 20]},\n    {\"id\": 84, \"name\": \"bottle_or_cup\", \"isthing\": 1, \"color\": [255, 184, 184]},\n    {\"id\": 85, \"name\": \"barrel\", \"isthing\": 1, \"color\": [0, 31, 255]},\n    {\"id\": 86, \"name\": \"fishbowl\", \"isthing\": 1, \"color\": [0, 255, 61]},\n    {\"id\": 87, \"name\": \"bed\", \"isthing\": 1, \"color\": [0, 71, 255]},\n    {\"id\": 88, \"name\": \"pillow\", \"isthing\": 1, \"color\": [255, 0, 204]},\n    {\"id\": 89, \"name\": \"table_or_desk\", \"isthing\": 1, \"color\": [0, 255, 194]},\n    {\"id\": 90, \"name\": \"chair_or_seat\", \"isthing\": 1, \"color\": [0, 255, 82]},\n    
{\"id\": 91, \"name\": \"bench\", \"isthing\": 1, \"color\": [0, 10, 255]},\n    {\"id\": 92, \"name\": \"sofa\", \"isthing\": 1, \"color\": [0, 112, 255]},\n    {\"id\": 93, \"name\": \"shelf\", \"isthing\": 0, \"color\": [51, 0, 255]},\n    {\"id\": 94, \"name\": \"bathtub\", \"isthing\": 0, \"color\": [0, 194, 255]},\n    {\"id\": 95, \"name\": \"gun\", \"isthing\": 1, \"color\": [0, 122, 255]},\n    {\"id\": 96, \"name\": \"commode\", \"isthing\": 1, \"color\": [0, 255, 163]},\n    {\"id\": 97, \"name\": \"roaster\", \"isthing\": 1, \"color\": [255, 153, 0]},\n    {\"id\": 98, \"name\": \"other_machine\", \"isthing\": 0, \"color\": [0, 255, 10]},\n    {\"id\": 99, \"name\": \"refrigerator\", \"isthing\": 1, \"color\": [255, 112, 0]},\n    {\"id\": 100, \"name\": \"washing_machine\", \"isthing\": 1, \"color\": [143, 255, 0]},\n    {\"id\": 101, \"name\": \"Microwave_oven\", \"isthing\": 1, \"color\": [82, 0, 255]},\n    {\"id\": 102, \"name\": \"fan\", \"isthing\": 1, \"color\": [163, 255, 0]},\n    {\"id\": 103, \"name\": \"curtain\", \"isthing\": 0, \"color\": [255, 235, 0]},\n    {\"id\": 104, \"name\": \"textiles\", \"isthing\": 0, \"color\": [8, 184, 170]},\n    {\"id\": 105, \"name\": \"clothes\", \"isthing\": 0, \"color\": [133, 0, 255]},\n    {\"id\": 106, \"name\": \"painting_or_poster\", \"isthing\": 1, \"color\": [0, 255, 92]},\n    {\"id\": 107, \"name\": \"mirror\", \"isthing\": 1, \"color\": [184, 0, 255]},\n    {\"id\": 108, \"name\": \"flower_pot_or_vase\", \"isthing\": 1, \"color\": [255, 0, 31]},\n    {\"id\": 109, \"name\": \"clock\", \"isthing\": 1, \"color\": [0, 184, 255]},\n    {\"id\": 110, \"name\": \"book\", \"isthing\": 0, \"color\": [0, 214, 255]},\n    {\"id\": 111, \"name\": \"tool\", \"isthing\": 0, \"color\": [255, 0, 112]},\n    {\"id\": 112, \"name\": \"blackboard\", \"isthing\": 0, \"color\": [92, 255, 0]},\n    {\"id\": 113, \"name\": \"tissue\", \"isthing\": 0, \"color\": [0, 224, 255]},\n    {\"id\": 114, \"name\": 
\"screen_or_television\", \"isthing\": 1, \"color\": [112, 224, 255]},\n    {\"id\": 115, \"name\": \"computer\", \"isthing\": 1, \"color\": [70, 184, 160]},\n    {\"id\": 116, \"name\": \"printer\", \"isthing\": 1, \"color\": [163, 0, 255]},\n    {\"id\": 117, \"name\": \"Mobile_phone\", \"isthing\": 1, \"color\": [153, 0, 255]},\n    {\"id\": 118, \"name\": \"keyboard\", \"isthing\": 1, \"color\": [71, 255, 0]},\n    {\"id\": 119, \"name\": \"other_electronic_product\", \"isthing\": 0, \"color\": [255, 0, 163]},\n    {\"id\": 120, \"name\": \"fruit\", \"isthing\": 0, \"color\": [255, 204, 0]},\n    {\"id\": 121, \"name\": \"food\", \"isthing\": 0, \"color\": [255, 0, 143]},\n    {\"id\": 122, \"name\": \"instrument\", \"isthing\": 1, \"color\": [0, 255, 235]},\n    {\"id\": 123, \"name\": \"train\", \"isthing\": 1, \"color\": [133, 255, 0]}\n]\n\nCLASSES_THING = [\n    {'id': 2, 'name': 'door', 'isthing': 1, 'color': [6, 230, 230]},\n    {'id': 4, 'name': 'ladder', 'isthing': 1, 'color': [4, 200, 3]},\n    {'id': 8, 'name': 'window', 'isthing': 1, 'color': [230, 230, 230]},\n    {'id': 10, 'name': 'goal', 'isthing': 1, 'color': [224, 5, 255]},\n    {'id': 41, 'name': 'sculpture', 'isthing': 1, 'color': [0, 255, 20]},\n    {'id': 43, 'name': 'flag', 'isthing': 1, 'color': [255, 5, 153]},\n    {'id': 44, 'name': 'parasol_or_umbrella', 'isthing': 1, 'color': [6, 51, 255]},\n    {'id': 46, 'name': 'tent', 'isthing': 1, 'color': [160, 150, 20]},\n    {'id': 47, 'name': 'roadblock', 'isthing': 1, 'color': [0, 163, 255]},\n    {'id': 48, 'name': 'car', 'isthing': 1, 'color': [140, 140, 140]},\n    {'id': 49, 'name': 'bus', 'isthing': 1, 'color': [250, 10, 15]},\n    {'id': 50, 'name': 'truck', 'isthing': 1, 'color': [20, 255, 0]},\n    {'id': 51, 'name': 'bicycle', 'isthing': 1, 'color': [31, 255, 0]},\n    {'id': 52, 'name': 'motorcycle', 'isthing': 1, 'color': [255, 31, 0]},\n    {'id': 54, 'name': 'ship_or_boat', 'isthing': 1, 'color': [153, 255, 0]},\n    {'id': 55, 
'name': 'raft', 'isthing': 1, 'color': [0, 0, 255]},\n    {'id': 56, 'name': 'airplane', 'isthing': 1, 'color': [255, 71, 0]},\n    {'id': 60, 'name': 'person', 'isthing': 1, 'color': [11, 200, 200]},\n    {'id': 61, 'name': 'cat', 'isthing': 1, 'color': [255, 82, 0]},\n    {'id': 62, 'name': 'dog', 'isthing': 1, 'color': [0, 255, 245]},\n    {'id': 63, 'name': 'horse', 'isthing': 1, 'color': [0, 61, 255]},\n    {'id': 64, 'name': 'cattle', 'isthing': 1, 'color': [0, 255, 112]},\n    {'id': 65, 'name': 'other_animal', 'isthing': 1, 'color': [0, 255, 133]},\n    {'id': 72, 'name': 'skateboard', 'isthing': 1, 'color': [0, 82, 255]},\n    {'id': 74, 'name': 'ball', 'isthing': 1, 'color': [0, 255, 173]},\n    {'id': 76, 'name': 'box', 'isthing': 1, 'color': [173, 255, 0]},\n    {'id': 77, 'name': 'traveling_case_or_trolley_case', 'isthing': 1, 'color': [0, 255, 153]},\n    {'id': 78, 'name': 'basket', 'isthing': 1, 'color': [255, 92, 0]},\n    {'id': 79, 'name': 'bag_or_package', 'isthing': 1, 'color': [255, 0, 255]},\n    {'id': 82, 'name': 'plate', 'isthing': 1, 'color': [255, 173, 0]},\n    {'id': 83, 'name': 'tub_or_bowl_or_pot', 'isthing': 1, 'color': [255, 0, 20]},\n    {'id': 84, 'name': 'bottle_or_cup', 'isthing': 1, 'color': [255, 184, 184]},\n    {'id': 85, 'name': 'barrel', 'isthing': 1, 'color': [0, 31, 255]},\n    {'id': 86, 'name': 'fishbowl', 'isthing': 1, 'color': [0, 255, 61]},\n    {'id': 87, 'name': 'bed', 'isthing': 1, 'color': [0, 71, 255]},\n    {'id': 88, 'name': 'pillow', 'isthing': 1, 'color': [255, 0, 204]},\n    {'id': 89, 'name': 'table_or_desk', 'isthing': 1, 'color': [0, 255, 194]},\n    {'id': 90, 'name': 'chair_or_seat', 'isthing': 1, 'color': [0, 255, 82]},\n    {'id': 91, 'name': 'bench', 'isthing': 1, 'color': [0, 10, 255]},\n    {'id': 92, 'name': 'sofa', 'isthing': 1, 'color': [0, 112, 255]},\n    {'id': 95, 'name': 'gun', 'isthing': 1, 'color': [0, 122, 255]},\n    {'id': 96, 'name': 'commode', 'isthing': 1, 'color': [0, 255, 
163]},\n    {'id': 97, 'name': 'roaster', 'isthing': 1, 'color': [255, 153, 0]},\n    {'id': 99, 'name': 'refrigerator', 'isthing': 1, 'color': [255, 112, 0]},\n    {'id': 100, 'name': 'washing_machine', 'isthing': 1, 'color': [143, 255, 0]},\n    {'id': 101, 'name': 'Microwave_oven', 'isthing': 1, 'color': [82, 0, 255]},\n    {'id': 102, 'name': 'fan', 'isthing': 1, 'color': [163, 255, 0]},\n    {'id': 106, 'name': 'painting_or_poster', 'isthing': 1, 'color': [0, 255, 92]},\n    {'id': 107, 'name': 'mirror', 'isthing': 1, 'color': [184, 0, 255]},\n    {'id': 108, 'name': 'flower_pot_or_vase', 'isthing': 1, 'color': [255, 0, 31]},\n    {'id': 109, 'name': 'clock', 'isthing': 1, 'color': [0, 184, 255]},\n    {'id': 114, 'name': 'screen_or_television', 'isthing': 1, 'color': [112, 224, 255]},\n    {'id': 115, 'name': 'computer', 'isthing': 1, 'color': [70, 184, 160]},\n    {'id': 116, 'name': 'printer', 'isthing': 1, 'color': [163, 0, 255]},\n    {'id': 117, 'name': 'Mobile_phone', 'isthing': 1, 'color': [153, 0, 255]},\n    {'id': 118, 'name': 'keyboard', 'isthing': 1, 'color': [71, 255, 0]},\n    {'id': 122, 'name': 'instrument', 'isthing': 1, 'color': [0, 255, 235]},\n    {'id': 123, 'name': 'train', 'isthing': 1, 'color': [133, 255, 0]}\n]\n\nCLASSES_STUFF = [\n    {'id': 0, 'name': 'wall', 'isthing': 0, 'color': [120, 120, 120]},\n    {'id': 1, 'name': 'ceiling', 'isthing': 0, 'color': [180, 120, 120]},\n    {'id': 3, 'name': 'stair', 'isthing': 0, 'color': [80, 50, 50]},\n    {'id': 5, 'name': 'escalator', 'isthing': 0, 'color': [120, 120, 80]},\n    {'id': 6, 'name': 'Playground_slide', 'isthing': 0, 'color': [140, 140, 140]},\n    {'id': 7, 'name': 'handrail_or_fence', 'isthing': 0, 'color': [204, 5, 255]},\n    {'id': 9, 'name': 'rail', 'isthing': 0, 'color': [4, 250, 7]},\n    {'id': 11, 'name': 'pillar', 'isthing': 0, 'color': [235, 255, 7]},\n    {'id': 12, 'name': 'pole', 'isthing': 0, 'color': [150, 5, 61]},\n    {'id': 13, 'name': 'floor', 'isthing': 
0, 'color': [120, 120, 70]},\n    {'id': 14, 'name': 'ground', 'isthing': 0, 'color': [8, 255, 51]},\n    {'id': 15, 'name': 'grass', 'isthing': 0, 'color': [255, 6, 82]},\n    {'id': 16, 'name': 'sand', 'isthing': 0, 'color': [143, 255, 140]},\n    {'id': 17, 'name': 'athletic_field', 'isthing': 0, 'color': [204, 255, 4]},\n    {'id': 18, 'name': 'road', 'isthing': 0, 'color': [255, 51, 7]},\n    {'id': 19, 'name': 'path', 'isthing': 0, 'color': [204, 70, 3]},\n    {'id': 20, 'name': 'crosswalk', 'isthing': 0, 'color': [0, 102, 200]},\n    {'id': 21, 'name': 'building', 'isthing': 0, 'color': [61, 230, 250]},\n    {'id': 22, 'name': 'house', 'isthing': 0, 'color': [255, 6, 51]},\n    {'id': 23, 'name': 'bridge', 'isthing': 0, 'color': [11, 102, 255]},\n    {'id': 24, 'name': 'tower', 'isthing': 0, 'color': [255, 7, 71]},\n    {'id': 25, 'name': 'windmill', 'isthing': 0, 'color': [255, 9, 224]},\n    {'id': 26, 'name': 'well_or_well_lid', 'isthing': 0, 'color': [9, 7, 230]},\n    {'id': 27, 'name': 'other_construction', 'isthing': 0, 'color': [220, 220, 220]},\n    {'id': 28, 'name': 'sky', 'isthing': 0, 'color': [255, 9, 92]},\n    {'id': 29, 'name': 'mountain', 'isthing': 0, 'color': [112, 9, 255]},\n    {'id': 30, 'name': 'stone', 'isthing': 0, 'color': [8, 255, 214]},\n    {'id': 31, 'name': 'wood', 'isthing': 0, 'color': [7, 255, 224]},\n    {'id': 32, 'name': 'ice', 'isthing': 0, 'color': [255, 184, 6]},\n    {'id': 33, 'name': 'snowfield', 'isthing': 0, 'color': [10, 255, 71]},\n    {'id': 34, 'name': 'grandstand', 'isthing': 0, 'color': [255, 41, 10]},\n    {'id': 35, 'name': 'sea', 'isthing': 0, 'color': [7, 255, 255]},\n    {'id': 36, 'name': 'river', 'isthing': 0, 'color': [224, 255, 8]},\n    {'id': 37, 'name': 'lake', 'isthing': 0, 'color': [102, 8, 255]},\n    {'id': 38, 'name': 'waterfall', 'isthing': 0, 'color': [255, 61, 6]},\n    {'id': 39, 'name': 'water', 'isthing': 0, 'color': [255, 194, 7]},\n    {'id': 40, 'name': 
'billboard_or_Bulletin_Board', 'isthing': 0, 'color': [255, 122, 8]},\n    {'id': 42, 'name': 'pipeline', 'isthing': 0, 'color': [255, 8, 41]},\n    {'id': 45, 'name': 'cushion_or_carpet', 'isthing': 0, 'color': [235, 12, 255]},\n    {'id': 53, 'name': 'wheeled_machine', 'isthing': 0, 'color': [255, 224, 0]},\n    {'id': 57, 'name': 'tyre', 'isthing': 0, 'color': [0, 235, 255]},\n    {'id': 58, 'name': 'traffic_light', 'isthing': 0, 'color': [0, 173, 255]},\n    {'id': 59, 'name': 'lamp', 'isthing': 0, 'color': [31, 0, 255]},\n    {'id': 66, 'name': 'tree', 'isthing': 0, 'color': [255, 0, 0]},\n    {'id': 67, 'name': 'flower', 'isthing': 0, 'color': [255, 163, 0]},\n    {'id': 68, 'name': 'other_plant', 'isthing': 0, 'color': [255, 102, 0]},\n    {'id': 69, 'name': 'toy', 'isthing': 0, 'color': [194, 255, 0]},\n    {'id': 70, 'name': 'ball_net', 'isthing': 0, 'color': [0, 143, 255]},\n    {'id': 71, 'name': 'backboard', 'isthing': 0, 'color': [51, 255, 0]},\n    {'id': 73, 'name': 'bat', 'isthing': 0, 'color': [0, 255, 41]},\n    {'id': 75, 'name': 'cupboard_or_showcase_or_storage_rack', 'isthing': 0, 'color': [10, 0, 255]},\n    {'id': 80, 'name': 'trash_can', 'isthing': 0, 'color': [255, 0, 245]},\n    {'id': 81, 'name': 'cage', 'isthing': 0, 'color': [255, 0, 102]},\n    {'id': 93, 'name': 'shelf', 'isthing': 0, 'color': [51, 0, 255]},\n    {'id': 94, 'name': 'bathtub', 'isthing': 0, 'color': [0, 194, 255]},\n    {'id': 98, 'name': 'other_machine', 'isthing': 0, 'color': [0, 255, 10]},\n    {'id': 103, 'name': 'curtain', 'isthing': 0, 'color': [255, 235, 0]},\n    {'id': 104, 'name': 'textiles', 'isthing': 0, 'color': [8, 184, 170]},\n    {'id': 105, 'name': 'clothes', 'isthing': 0, 'color': [133, 0, 255]},\n    {'id': 110, 'name': 'book', 'isthing': 0, 'color': [0, 214, 255]},\n    {'id': 111, 'name': 'tool', 'isthing': 0, 'color': [255, 0, 112]},\n    {'id': 112, 'name': 'blackboard', 'isthing': 0, 'color': [92, 255, 0]},\n    {'id': 113, 'name': 'tissue', 
'isthing': 0, 'color': [0, 224, 255]},\n    {'id': 119, 'name': 'other_electronic_product', 'isthing': 0, 'color': [255, 0, 163]},\n    {'id': 120, 'name': 'fruit', 'isthing': 0, 'color': [255, 204, 0]},\n    {'id': 121, 'name': 'food', 'isthing': 0, 'color': [255, 0, 143]}\n]\n\nNO_OBJ = 0\nNO_OBJ_HB = 255\nDIVISOR_PAN = 100\nDIVISOR_NEW = 1000\nNUM_THING = 58\nNUM_STUFF = 66\nTHING_B_STUFF = False\n\n\ndef vip2hb(pan_map):\n    assert not THING_B_STUFF, \"VIPSeg only supports stuff -> thing\"\n    pan_new = - np.ones_like(pan_map)\n    vip2hb_thing = {itm['id'] + 1: idx for idx, itm in enumerate(CLASSES_THING)}\n    vip2hb_stuff = {itm['id'] + 1: idx for idx, itm in enumerate(CLASSES_STUFF)}\n    for idx in np.unique(pan_map):\n        if idx == NO_OBJ or idx == 200:\n            pan_new[pan_map == idx] = NO_OBJ_HB * DIVISOR_NEW\n        elif idx > 128:\n            cls_id = idx // DIVISOR_PAN\n            cls_new_id = vip2hb_thing[cls_id]\n            inst_id = idx % DIVISOR_PAN\n            # since stuff first -> thing the second\n            cls_new_id += NUM_STUFF\n            pan_new[pan_map == idx] = cls_new_id * DIVISOR_NEW + inst_id + 1\n        else:\n            pan_new[pan_map == idx] = vip2hb_stuff[idx] * DIVISOR_NEW\n    assert -1. 
not in np.unique(pan_new)\n    return pan_new\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description='Evaluation of DSTQ')\n    parser.add_argument('result_path')\n    parser.add_argument('--gt-path', default='data/VIPSeg')\n    parser.add_argument('--split', default='val')\n    parser.add_argument(\n        '--depth',\n        action='store_true',\n        help='eval depth')\n    parser.add_argument('--nproc', default=1, type=int, help='number of processes')\n    args = parser.parse_args()\n    return args\n\n\ndef updater(pred_ins_name,\n            pred_cls_name,\n            pred_dep_name,\n            gt_pan_seq_name,\n            gt_dep_seq_name,\n            updater_obj,\n            seq_id):\n    pred_ins = mmcv.imread(pred_ins_name, flag='unchanged').astype(np.int32)\n    pred_cls = mmcv.imread(pred_cls_name, flag='unchanged').astype(np.int32)\n    pred_dep = mmcv.imread(pred_dep_name, flag='unchanged').astype(np.float32) if pred_dep_name is not None else None\n\n    gt_pan = mmcv.imread(gt_pan_seq_name, flag='unchanged').astype(np.int64)\n    gt_pan = vip2hb(gt_pan)\n\n    gt_cls = gt_pan // DIVISOR_NEW\n    gt_ins = gt_pan % DIVISOR_NEW\n    gt_dep = mmcv.imread(gt_dep_seq_name, flag='unchanged').astype(np.float32) if gt_dep_seq_name is not None else None\n    if pred_dep is not None:\n        pred_dep = F.interpolate(torch.from_numpy(pred_dep)[None][None], size=gt_dep.shape)[0][0].numpy()\n\n    valid_mask_seg = gt_cls != NO_OBJ_HB\n\n    pred_masked_ps = pred_cls[valid_mask_seg] * (2 ** 16) + pred_ins[valid_mask_seg]\n    gt_masked_ps = gt_cls[valid_mask_seg] * (2 ** 16) + gt_ins[valid_mask_seg]\n\n    if pred_dep_name is not None:\n        valid_mask_dep = gt_dep > 0.\n\n        pred_masked_depth = pred_dep[valid_mask_dep]\n        gt_masked_depth = gt_dep[valid_mask_dep]\n\n        updater_obj.update_state(gt_masked_ps, pred_masked_ps, gt_masked_depth, pred_masked_depth, seq_id)\n    else:\n        updater_obj.update_state(gt_masked_ps, pred_masked_ps, seq_id)
\n\n\ndef eval_dstq(result_dir, gt_dir, with_depth=True, split='val'):\n    if with_depth:\n        dstq_obj = DSTQuality(\n            num_classes=len(CLASSES),\n            things_list=list(range(66, 124)),\n            ignore_label=NO_OBJ_HB,\n            label_bit_shift=16,\n            offset=2 ** 16 * 256,\n            depth_threshold=(1.25,),\n        )\n    else:\n        dstq_obj = STQuality(\n            num_classes=len(CLASSES),\n            things_list=list(range(66, 124)),\n            ignore_label=NO_OBJ_HB,\n            label_bit_shift=16,\n            offset=2 ** 16 * 256,\n        )\n    ann_folders = mmcv.list_from_file(os.path.join(gt_dir, \"{}.txt\".format(split)),\n                                      prefix=os.path.join(gt_dir, 'panomasks') + '/')\n    seq_ids = np.arange(0, len(ann_folders)).tolist()\n\n    for seq_id in seq_ids:\n\n        gt_names = list(mmcv.scandir(ann_folders[seq_id]))\n        gt_pan_names = sorted(list(filter(lambda x: '.png' in x, gt_names)))\n        if with_depth:\n            gt_dep_names = sorted(list(filter(lambda x: 'depth' in x, gt_names)))\n        else:\n            gt_dep_names = [None] * len(gt_pan_names)\n        pred_name_panoptic = list(mmcv.scandir(os.path.join(result_dir, 'panoptic', str(seq_id))))\n        pred_ins_names = sorted(list(filter(lambda x: 'ins' in x, pred_name_panoptic)))\n        pred_cls_names = sorted(list(filter(lambda x: 'cat' in x, pred_name_panoptic)))\n        if len(gt_pan_names) != len(pred_ins_names):\n            print(
\"Error at seq_id {}: the numbers of predicted and ground-truth frames differ. Only the sequences before it are evaluated.\".format(seq_id))\n            break\n        if with_depth:\n            pred_name_depth = list(mmcv.scandir(os.path.join(result_dir, 'depth', str(seq_id))))\n            pred_dep_names = sorted(pred_name_depth)\n        else:\n            pred_dep_names = [None] * len(pred_ins_names)\n        prog_bar = ProgressBar(len(pred_ins_names))\n        for pred_ins_name, pred_cls_name, pred_dep_name, gt_pan_seq_name, gt_dep_seq_name in zip(\n                pred_ins_names, pred_cls_names, pred_dep_names, gt_pan_names, gt_dep_names\n        ):\n            prog_bar.update()\n            updater(\n                os.path.join(result_dir, 'panoptic', str(seq_id), pred_ins_name),\n                os.path.join(result_dir, 'panoptic', str(seq_id), pred_cls_name),\n                os.path.join(result_dir, 'depth', str(seq_id), pred_dep_name) if pred_dep_name is not None else None,\n                os.path.join(ann_folders[seq_id], gt_pan_seq_name),\n                os.path.join(ann_folders[seq_id], gt_dep_seq_name) if gt_dep_seq_name is not None else None,\n                dstq_obj,\n                seq_id\n            )\n    result = dstq_obj.result()\n    print(result)\n\n\n# Usage: python eval_dstq_vipseg.py /opt/data/results/test --gt-path /opt/data/VIPSeg\nif __name__ == '__main__':\n    args = parse_args()\n    result_path = args.result_path\n    gt_path = args.gt_path\n    split = args.split\n    eval_dstq(result_path, gt_path, with_depth=args.depth, split=split)\n"
  },
  {
    "path": "tools/eval_dvpq_step.py",
    "content": "import numpy as np\nfrom PIL import Image\nimport six\nimport os\nimport multiprocessing as mp\nimport argparse\n\nparser = argparse.ArgumentParser(description='')\nparser.add_argument('result_path')\nparser.add_argument('--eval_frames', type=int, default=1)\nparser.add_argument('--depth_thres', type=float, default=0)\nargs = parser.parse_args()\n\neval_frames = args.eval_frames\npred_dir_all = os.path.join(args.result_path, 'panoptic')\ndepth_dir_all = os.path.join(args.result_path, 'depth')\ngt_dir = 'data/kitti-step/video_sequence/val/'\ndepth_thres = args.depth_thres\n\n\ndef vpq_eval(element):\n    pred_ids, gt_ids = element\n    max_ins = 2 ** 16\n    ign_id = 255\n    offset = 2 ** 30\n    num_cat = 20\n\n    iou_per_class = np.zeros(num_cat, dtype=np.float64)\n    tp_per_class = np.zeros(num_cat, dtype=np.float64)\n    fn_per_class = np.zeros(num_cat, dtype=np.float64)\n    fp_per_class = np.zeros(num_cat, dtype=np.float64)\n\n    def _ids_to_counts(id_array):\n        ids, counts = np.unique(id_array, return_counts=True)\n        return dict(six.moves.zip(ids, counts))\n\n    pred_areas = _ids_to_counts(pred_ids)\n    gt_areas = _ids_to_counts(gt_ids)\n\n    void_id = ign_id * max_ins\n    ign_ids = {\n        gt_id for gt_id in six.iterkeys(gt_areas)\n        if (gt_id // max_ins) == ign_id\n    }\n\n    int_ids = gt_ids.astype(np.int64) * offset + pred_ids.astype(np.int64)\n    int_areas = _ids_to_counts(int_ids)\n\n    def prediction_void_overlap(pred_id):\n        void_int_id = void_id * offset + pred_id\n        return int_areas.get(void_int_id, 0)\n\n    def prediction_ignored_overlap(pred_id):\n        total_ignored_overlap = 0\n        for _ign_id in ign_ids:\n            int_id = _ign_id * offset + pred_id\n            total_ignored_overlap += int_areas.get(int_id, 0)\n        return total_ignored_overlap\n\n    gt_matched = set()\n    pred_matched = set()\n\n    for int_id, int_area in six.iteritems(int_areas):\n        gt_id = 
int(int_id // offset)\n        gt_cat = int(gt_id // max_ins)\n        pred_id = int(int_id % offset)\n        pred_cat = int(pred_id // max_ins)\n        if gt_cat != pred_cat:\n            continue\n        union = (\n            gt_areas[gt_id] + pred_areas[pred_id] - int_area -\n            prediction_void_overlap(pred_id)\n        )\n        iou = int_area / union\n        if iou > 0.5:\n            tp_per_class[gt_cat] += 1\n            iou_per_class[gt_cat] += iou\n            gt_matched.add(gt_id)\n            pred_matched.add(pred_id)\n\n    for gt_id in six.iterkeys(gt_areas):\n        if gt_id in gt_matched:\n            continue\n        cat_id = gt_id // max_ins\n        if cat_id == ign_id:\n            continue\n        fn_per_class[cat_id] += 1\n\n    for pred_id in six.iterkeys(pred_areas):\n        if pred_id in pred_matched:\n            continue\n        if (prediction_ignored_overlap(pred_id) / pred_areas[pred_id]) > 0.5:\n            continue\n        cat = pred_id // max_ins\n        fp_per_class[cat] += 1\n\n    return (iou_per_class, tp_per_class, fn_per_class, fp_per_class)\n\n\ndef eval(element):\n    max_ins = 2 ** 16\n\n    pred_cat, pred_ins, gts, depth_preds, depth_gts = element\n    pred_cat = [np.array(Image.open(image)) for image in pred_cat]\n    pred_ins = [np.array(Image.open(image)) for image in pred_ins]\n    pred_cat = np.concatenate(pred_cat, axis=1)\n    pred_ins = np.concatenate(pred_ins, axis=1)\n    pred = pred_cat.astype(np.int32) * max_ins + pred_ins.astype(np.int32)\n\n    gts_pan = [np.array(Image.open(image)) for image in gts]\n    gts = [gt_pan[..., 0].astype(np.int32) * max_ins +\n           gt_pan[..., 1].astype(np.int32) * 256 + gt_pan[..., 2].astype(np.int32)\n           for gt_pan in gts_pan]\n\n    abs_rel = 0.\n    if depth_thres > 0:\n        depth_preds = [np.array(Image.open(name)) for name in depth_preds]\n        depth_gts = [np.array(Image.open(name)) for name in depth_gts]\n        depth_preds = 
np.concatenate(depth_preds, axis=1)\n        depth_gts = np.concatenate(depth_gts, axis=1)\n        depth_mask = depth_gts > 0\n        abs_rel = np.mean(\n            np.abs(\n                depth_preds[depth_mask] -\n                depth_gts[depth_mask]) /\n            depth_gts[depth_mask])\n        pred_in_mask = pred[:, :depth_preds.shape[1]]\n        pred_in_depth_mask = pred_in_mask[depth_mask]\n        ignored_pred_mask = (\n            np.abs(\n                depth_preds[depth_mask] -\n                depth_gts[depth_mask]) /\n            depth_gts[depth_mask]) > depth_thres\n        pred_in_depth_mask[ignored_pred_mask] = 19 * max_ins\n        pred_in_mask[depth_mask] = pred_in_depth_mask\n        pred[:, :depth_preds.shape[1]] = pred_in_mask\n\n    gt = np.concatenate(gts, axis=1)\n    result = vpq_eval([pred, gt])\n\n    return result + (abs_rel, )\n\n\ndef main():\n    gt_names_all = os.scandir(gt_dir)\n    gt_names_all = [name.name for name in gt_names_all if 'panoptic' in name.name]\n    gt_names_all = [os.path.join(gt_dir, name) for name in gt_names_all]\n    gt_names_all = sorted(gt_names_all)\n\n    if args.depth_thres > 0:\n        depth_gt_names_all = os.scandir(gt_dir)\n        depth_gt_names_all = [\n            name.name for name in depth_gt_names_all if 'depth' in name.name]\n        depth_gt_names_all = [os.path.join(gt_dir, name) for name in depth_gt_names_all]\n        depth_gt_names_all = sorted(depth_gt_names_all)\n\n    iou_per_class_all = []\n    tp_per_class_all = []\n    fn_per_class_all = []\n    fp_per_class_all = []\n\n    things_index = np.zeros((19,)).astype(bool)\n    things_index[11] = True\n    things_index[13] = True\n\n    for i in [2, 6, 7, 8, 10, 13, 14, 16, 18]:\n        if args.depth_thres > 0:\n            depth_dir = os.path.join(depth_dir_all, str(i))\n            depth_pred_names = os.scandir(depth_dir)\n            depth_pred_names = [name.name for name in depth_pred_names]\n            depth_pred_names = 
[os.path.join(depth_dir, name)\n                                for name in depth_pred_names]\n            depth_pred_names = sorted(depth_pred_names)\n\n        pred_dir = os.path.join(pred_dir_all, str(i))\n        pred_names = os.scandir(pred_dir)\n        pred_names = [os.path.join(pred_dir, name.name) for name in pred_names]\n        cat_pred_names = [name for name in pred_names if name.endswith('cat.png')]\n        ins_pred_names = [name for name in pred_names if name.endswith('ins.png')]\n        cat_pred_names = sorted(cat_pred_names)\n        ins_pred_names = sorted(ins_pred_names)\n\n        all_lst = []\n        gt_names = sorted(list(filter(lambda x: os.path.basename(x).startswith('{:06d}'.format(i)), gt_names_all)))\n        if args.depth_thres > 0:\n            depth_gt_names = sorted(list(filter(lambda x: os.path.basename(x).startswith('{:06d}'.format(i)), depth_gt_names_all)))\n        # sliding windows of eval_frames consecutive frames; j avoids shadowing the sequence index i\n        for j in range(len(cat_pred_names) - eval_frames + 1):\n            all_lst.append([cat_pred_names[j: j + eval_frames],\n                            ins_pred_names[j: j + eval_frames],\n                            gt_names[j: j + eval_frames],\n                            depth_pred_names[j: j + eval_frames] if args.depth_thres > 0 else None,\n                            depth_gt_names[j: j + eval_frames] if args.depth_thres > 0 else None\n                            ])\n\n        # use half of the CPU cores, but always at least one worker process\n        N = max(1, mp.cpu_count() // 2)\n        with mp.Pool(processes=N) as p:\n            results = p.map(eval, all_lst)\n        iou_per_class = np.stack([result[0] for result in results])\n        iou_per_class_all.append(iou_per_class)\n        tp_per_class = np.stack([result[1] for result in results])\n        tp_per_class_all.append(tp_per_class)\n        fn_per_class = np.stack([result[2] for result in results])\n        fn_per_class_all.append(fn_per_class)\n        fp_per_class = np.stack([result[3] for result in results])\n        fp_per_class_all.append(fp_per_class)\n        # abs_rel = 
np.stack([result[4] for result in results]).mean(axis=0)\n        epsilon = 1e-10\n        iou_per_class = iou_per_class.sum(axis=0)[:19]\n        tp_per_class = tp_per_class.sum(axis=0)[:19]\n        fn_per_class = fn_per_class.sum(axis=0)[:19]\n        fp_per_class = fp_per_class.sum(axis=0)[:19]\n        sq = iou_per_class / (tp_per_class + epsilon)\n        rq = tp_per_class / (tp_per_class + 0.5 *\n                             fn_per_class + 0.5 * fp_per_class + epsilon)\n        pq = sq * rq\n        spq = pq[np.logical_not(things_index)]\n        tpq = pq[things_index]\n        print(\n            r'{:.1f} {:.1f} {:.1f}'.format(\n                pq.mean() * 100,\n                tpq.mean() * 100,\n                spq.mean() * 100))\n\n    print(\"----------------final-----------------\")\n    iou_per_class_all = np.concatenate(iou_per_class_all, axis=0).sum(axis=0)[:19]\n    tp_per_class_all = np.concatenate(tp_per_class_all, axis=0).sum(axis=0)[:19]\n    fn_per_class_all = np.concatenate(fn_per_class_all, axis=0).sum(axis=0)[:19]\n    fp_per_class_all = np.concatenate(fp_per_class_all, axis=0).sum(axis=0)[:19]\n\n    sq = iou_per_class_all / (tp_per_class_all + epsilon)\n    rq = tp_per_class_all / (tp_per_class_all + 0.5 *\n                         fn_per_class_all + 0.5 * fp_per_class_all + epsilon)\n    pq = sq * rq\n    spq = pq[np.logical_not(things_index)]\n    tpq = pq[things_index]\n    print(\n        r'{:.1f} {:.1f} {:.1f}'.format(\n            pq.mean() * 100,\n            tpq.mean() * 100,\n            spq.mean() * 100))\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "tools/eval_dvpq_vipseg.py",
    "content": "import argparse\nimport os\n\nimport mmcv\nimport numpy as np\nimport six\nimport multiprocessing as mp\n\nCLASSES = [\n    {\"id\": 0, \"name\": \"wall\", \"isthing\": 0, \"color\": [120, 120, 120]},\n    {\"id\": 1, \"name\": \"ceiling\", \"isthing\": 0, \"color\": [180, 120, 120]},\n    {\"id\": 2, \"name\": \"door\", \"isthing\": 1, \"color\": [6, 230, 230]},\n    {\"id\": 3, \"name\": \"stair\", \"isthing\": 0, \"color\": [80, 50, 50]},\n    {\"id\": 4, \"name\": \"ladder\", \"isthing\": 1, \"color\": [4, 200, 3]},\n    {\"id\": 5, \"name\": \"escalator\", \"isthing\": 0, \"color\": [120, 120, 80]},\n    {\"id\": 6, \"name\": \"Playground_slide\", \"isthing\": 0, \"color\": [140, 140, 140]},\n    {\"id\": 7, \"name\": \"handrail_or_fence\", \"isthing\": 0, \"color\": [204, 5, 255]},\n    {\"id\": 8, \"name\": \"window\", \"isthing\": 1, \"color\": [230, 230, 230]},\n    {\"id\": 9, \"name\": \"rail\", \"isthing\": 0, \"color\": [4, 250, 7]},\n    {\"id\": 10, \"name\": \"goal\", \"isthing\": 1, \"color\": [224, 5, 255]},\n    {\"id\": 11, \"name\": \"pillar\", \"isthing\": 0, \"color\": [235, 255, 7]},\n    {\"id\": 12, \"name\": \"pole\", \"isthing\": 0, \"color\": [150, 5, 61]},\n    {\"id\": 13, \"name\": \"floor\", \"isthing\": 0, \"color\": [120, 120, 70]},\n    {\"id\": 14, \"name\": \"ground\", \"isthing\": 0, \"color\": [8, 255, 51]},\n    {\"id\": 15, \"name\": \"grass\", \"isthing\": 0, \"color\": [255, 6, 82]},\n    {\"id\": 16, \"name\": \"sand\", \"isthing\": 0, \"color\": [143, 255, 140]},\n    {\"id\": 17, \"name\": \"athletic_field\", \"isthing\": 0, \"color\": [204, 255, 4]},\n    {\"id\": 18, \"name\": \"road\", \"isthing\": 0, \"color\": [255, 51, 7]},\n    {\"id\": 19, \"name\": \"path\", \"isthing\": 0, \"color\": [204, 70, 3]},\n    {\"id\": 20, \"name\": \"crosswalk\", \"isthing\": 0, \"color\": [0, 102, 200]},\n    {\"id\": 21, \"name\": \"building\", \"isthing\": 0, \"color\": [61, 230, 250]},\n    {\"id\": 22, 
\"name\": \"house\", \"isthing\": 0, \"color\": [255, 6, 51]},\n    {\"id\": 23, \"name\": \"bridge\", \"isthing\": 0, \"color\": [11, 102, 255]},\n    {\"id\": 24, \"name\": \"tower\", \"isthing\": 0, \"color\": [255, 7, 71]},\n    {\"id\": 25, \"name\": \"windmill\", \"isthing\": 0, \"color\": [255, 9, 224]},\n    {\"id\": 26, \"name\": \"well_or_well_lid\", \"isthing\": 0, \"color\": [9, 7, 230]},\n    {\"id\": 27, \"name\": \"other_construction\", \"isthing\": 0, \"color\": [220, 220, 220]},\n    {\"id\": 28, \"name\": \"sky\", \"isthing\": 0, \"color\": [255, 9, 92]},\n    {\"id\": 29, \"name\": \"mountain\", \"isthing\": 0, \"color\": [112, 9, 255]},\n    {\"id\": 30, \"name\": \"stone\", \"isthing\": 0, \"color\": [8, 255, 214]},\n    {\"id\": 31, \"name\": \"wood\", \"isthing\": 0, \"color\": [7, 255, 224]},\n    {\"id\": 32, \"name\": \"ice\", \"isthing\": 0, \"color\": [255, 184, 6]},\n    {\"id\": 33, \"name\": \"snowfield\", \"isthing\": 0, \"color\": [10, 255, 71]},\n    {\"id\": 34, \"name\": \"grandstand\", \"isthing\": 0, \"color\": [255, 41, 10]},\n    {\"id\": 35, \"name\": \"sea\", \"isthing\": 0, \"color\": [7, 255, 255]},\n    {\"id\": 36, \"name\": \"river\", \"isthing\": 0, \"color\": [224, 255, 8]},\n    {\"id\": 37, \"name\": \"lake\", \"isthing\": 0, \"color\": [102, 8, 255]},\n    {\"id\": 38, \"name\": \"waterfall\", \"isthing\": 0, \"color\": [255, 61, 6]},\n    {\"id\": 39, \"name\": \"water\", \"isthing\": 0, \"color\": [255, 194, 7]},\n    {\"id\": 40, \"name\": \"billboard_or_Bulletin_Board\", \"isthing\": 0, \"color\": [255, 122, 8]},\n    {\"id\": 41, \"name\": \"sculpture\", \"isthing\": 1, \"color\": [0, 255, 20]},\n    {\"id\": 42, \"name\": \"pipeline\", \"isthing\": 0, \"color\": [255, 8, 41]},\n    {\"id\": 43, \"name\": \"flag\", \"isthing\": 1, \"color\": [255, 5, 153]},\n    {\"id\": 44, \"name\": \"parasol_or_umbrella\", \"isthing\": 1, \"color\": [6, 51, 255]},\n    {\"id\": 45, \"name\": \"cushion_or_carpet\", 
\"isthing\": 0, \"color\": [235, 12, 255]},\n    {\"id\": 46, \"name\": \"tent\", \"isthing\": 1, \"color\": [160, 150, 20]},\n    {\"id\": 47, \"name\": \"roadblock\", \"isthing\": 1, \"color\": [0, 163, 255]},\n    {\"id\": 48, \"name\": \"car\", \"isthing\": 1, \"color\": [140, 140, 140]},\n    {\"id\": 49, \"name\": \"bus\", \"isthing\": 1, \"color\": [250, 10, 15]},\n    {\"id\": 50, \"name\": \"truck\", \"isthing\": 1, \"color\": [20, 255, 0]},\n    {\"id\": 51, \"name\": \"bicycle\", \"isthing\": 1, \"color\": [31, 255, 0]},\n    {\"id\": 52, \"name\": \"motorcycle\", \"isthing\": 1, \"color\": [255, 31, 0]},\n    {\"id\": 53, \"name\": \"wheeled_machine\", \"isthing\": 0, \"color\": [255, 224, 0]},\n    {\"id\": 54, \"name\": \"ship_or_boat\", \"isthing\": 1, \"color\": [153, 255, 0]},\n    {\"id\": 55, \"name\": \"raft\", \"isthing\": 1, \"color\": [0, 0, 255]},\n    {\"id\": 56, \"name\": \"airplane\", \"isthing\": 1, \"color\": [255, 71, 0]},\n    {\"id\": 57, \"name\": \"tyre\", \"isthing\": 0, \"color\": [0, 235, 255]},\n    {\"id\": 58, \"name\": \"traffic_light\", \"isthing\": 0, \"color\": [0, 173, 255]},\n    {\"id\": 59, \"name\": \"lamp\", \"isthing\": 0, \"color\": [31, 0, 255]},\n    {\"id\": 60, \"name\": \"person\", \"isthing\": 1, \"color\": [11, 200, 200]},\n    {\"id\": 61, \"name\": \"cat\", \"isthing\": 1, \"color\": [255, 82, 0]},\n    {\"id\": 62, \"name\": \"dog\", \"isthing\": 1, \"color\": [0, 255, 245]},\n    {\"id\": 63, \"name\": \"horse\", \"isthing\": 1, \"color\": [0, 61, 255]},\n    {\"id\": 64, \"name\": \"cattle\", \"isthing\": 1, \"color\": [0, 255, 112]},\n    {\"id\": 65, \"name\": \"other_animal\", \"isthing\": 1, \"color\": [0, 255, 133]},\n    {\"id\": 66, \"name\": \"tree\", \"isthing\": 0, \"color\": [255, 0, 0]},\n    {\"id\": 67, \"name\": \"flower\", \"isthing\": 0, \"color\": [255, 163, 0]},\n    {\"id\": 68, \"name\": \"other_plant\", \"isthing\": 0, \"color\": [255, 102, 0]},\n    {\"id\": 69, \"name\": 
\"toy\", \"isthing\": 0, \"color\": [194, 255, 0]},\n    {\"id\": 70, \"name\": \"ball_net\", \"isthing\": 0, \"color\": [0, 143, 255]},\n    {\"id\": 71, \"name\": \"backboard\", \"isthing\": 0, \"color\": [51, 255, 0]},\n    {\"id\": 72, \"name\": \"skateboard\", \"isthing\": 1, \"color\": [0, 82, 255]},\n    {\"id\": 73, \"name\": \"bat\", \"isthing\": 0, \"color\": [0, 255, 41]},\n    {\"id\": 74, \"name\": \"ball\", \"isthing\": 1, \"color\": [0, 255, 173]},\n    {\"id\": 75, \"name\": \"cupboard_or_showcase_or_storage_rack\", \"isthing\": 0, \"color\": [10, 0, 255]},\n    {\"id\": 76, \"name\": \"box\", \"isthing\": 1, \"color\": [173, 255, 0]},\n    {\"id\": 77, \"name\": \"traveling_case_or_trolley_case\", \"isthing\": 1, \"color\": [0, 255, 153]},\n    {\"id\": 78, \"name\": \"basket\", \"isthing\": 1, \"color\": [255, 92, 0]},\n    {\"id\": 79, \"name\": \"bag_or_package\", \"isthing\": 1, \"color\": [255, 0, 255]},\n    {\"id\": 80, \"name\": \"trash_can\", \"isthing\": 0, \"color\": [255, 0, 245]},\n    {\"id\": 81, \"name\": \"cage\", \"isthing\": 0, \"color\": [255, 0, 102]},\n    {\"id\": 82, \"name\": \"plate\", \"isthing\": 1, \"color\": [255, 173, 0]},\n    {\"id\": 83, \"name\": \"tub_or_bowl_or_pot\", \"isthing\": 1, \"color\": [255, 0, 20]},\n    {\"id\": 84, \"name\": \"bottle_or_cup\", \"isthing\": 1, \"color\": [255, 184, 184]},\n    {\"id\": 85, \"name\": \"barrel\", \"isthing\": 1, \"color\": [0, 31, 255]},\n    {\"id\": 86, \"name\": \"fishbowl\", \"isthing\": 1, \"color\": [0, 255, 61]},\n    {\"id\": 87, \"name\": \"bed\", \"isthing\": 1, \"color\": [0, 71, 255]},\n    {\"id\": 88, \"name\": \"pillow\", \"isthing\": 1, \"color\": [255, 0, 204]},\n    {\"id\": 89, \"name\": \"table_or_desk\", \"isthing\": 1, \"color\": [0, 255, 194]},\n    {\"id\": 90, \"name\": \"chair_or_seat\", \"isthing\": 1, \"color\": [0, 255, 82]},\n    {\"id\": 91, \"name\": \"bench\", \"isthing\": 1, \"color\": [0, 10, 255]},\n    {\"id\": 92, \"name\": 
\"sofa\", \"isthing\": 1, \"color\": [0, 112, 255]},\n    {\"id\": 93, \"name\": \"shelf\", \"isthing\": 0, \"color\": [51, 0, 255]},\n    {\"id\": 94, \"name\": \"bathtub\", \"isthing\": 0, \"color\": [0, 194, 255]},\n    {\"id\": 95, \"name\": \"gun\", \"isthing\": 1, \"color\": [0, 122, 255]},\n    {\"id\": 96, \"name\": \"commode\", \"isthing\": 1, \"color\": [0, 255, 163]},\n    {\"id\": 97, \"name\": \"roaster\", \"isthing\": 1, \"color\": [255, 153, 0]},\n    {\"id\": 98, \"name\": \"other_machine\", \"isthing\": 0, \"color\": [0, 255, 10]},\n    {\"id\": 99, \"name\": \"refrigerator\", \"isthing\": 1, \"color\": [255, 112, 0]},\n    {\"id\": 100, \"name\": \"washing_machine\", \"isthing\": 1, \"color\": [143, 255, 0]},\n    {\"id\": 101, \"name\": \"Microwave_oven\", \"isthing\": 1, \"color\": [82, 0, 255]},\n    {\"id\": 102, \"name\": \"fan\", \"isthing\": 1, \"color\": [163, 255, 0]},\n    {\"id\": 103, \"name\": \"curtain\", \"isthing\": 0, \"color\": [255, 235, 0]},\n    {\"id\": 104, \"name\": \"textiles\", \"isthing\": 0, \"color\": [8, 184, 170]},\n    {\"id\": 105, \"name\": \"clothes\", \"isthing\": 0, \"color\": [133, 0, 255]},\n    {\"id\": 106, \"name\": \"painting_or_poster\", \"isthing\": 1, \"color\": [0, 255, 92]},\n    {\"id\": 107, \"name\": \"mirror\", \"isthing\": 1, \"color\": [184, 0, 255]},\n    {\"id\": 108, \"name\": \"flower_pot_or_vase\", \"isthing\": 1, \"color\": [255, 0, 31]},\n    {\"id\": 109, \"name\": \"clock\", \"isthing\": 1, \"color\": [0, 184, 255]},\n    {\"id\": 110, \"name\": \"book\", \"isthing\": 0, \"color\": [0, 214, 255]},\n    {\"id\": 111, \"name\": \"tool\", \"isthing\": 0, \"color\": [255, 0, 112]},\n    {\"id\": 112, \"name\": \"blackboard\", \"isthing\": 0, \"color\": [92, 255, 0]},\n    {\"id\": 113, \"name\": \"tissue\", \"isthing\": 0, \"color\": [0, 224, 255]},\n    {\"id\": 114, \"name\": \"screen_or_television\", \"isthing\": 1, \"color\": [112, 224, 255]},\n    {\"id\": 115, \"name\": \"computer\", 
\"isthing\": 1, \"color\": [70, 184, 160]},\n    {\"id\": 116, \"name\": \"printer\", \"isthing\": 1, \"color\": [163, 0, 255]},\n    {\"id\": 117, \"name\": \"Mobile_phone\", \"isthing\": 1, \"color\": [153, 0, 255]},\n    {\"id\": 118, \"name\": \"keyboard\", \"isthing\": 1, \"color\": [71, 255, 0]},\n    {\"id\": 119, \"name\": \"other_electronic_product\", \"isthing\": 0, \"color\": [255, 0, 163]},\n    {\"id\": 120, \"name\": \"fruit\", \"isthing\": 0, \"color\": [255, 204, 0]},\n    {\"id\": 121, \"name\": \"food\", \"isthing\": 0, \"color\": [255, 0, 143]},\n    {\"id\": 122, \"name\": \"instrument\", \"isthing\": 1, \"color\": [0, 255, 235]},\n    {\"id\": 123, \"name\": \"train\", \"isthing\": 1, \"color\": [133, 255, 0]}\n]\n\nCLASSES_THING = [\n    {'id': 2, 'name': 'door', 'isthing': 1, 'color': [6, 230, 230]},\n    {'id': 4, 'name': 'ladder', 'isthing': 1, 'color': [4, 200, 3]},\n    {'id': 8, 'name': 'window', 'isthing': 1, 'color': [230, 230, 230]},\n    {'id': 10, 'name': 'goal', 'isthing': 1, 'color': [224, 5, 255]},\n    {'id': 41, 'name': 'sculpture', 'isthing': 1, 'color': [0, 255, 20]},\n    {'id': 43, 'name': 'flag', 'isthing': 1, 'color': [255, 5, 153]},\n    {'id': 44, 'name': 'parasol_or_umbrella', 'isthing': 1, 'color': [6, 51, 255]},\n    {'id': 46, 'name': 'tent', 'isthing': 1, 'color': [160, 150, 20]},\n    {'id': 47, 'name': 'roadblock', 'isthing': 1, 'color': [0, 163, 255]},\n    {'id': 48, 'name': 'car', 'isthing': 1, 'color': [140, 140, 140]},\n    {'id': 49, 'name': 'bus', 'isthing': 1, 'color': [250, 10, 15]},\n    {'id': 50, 'name': 'truck', 'isthing': 1, 'color': [20, 255, 0]},\n    {'id': 51, 'name': 'bicycle', 'isthing': 1, 'color': [31, 255, 0]},\n    {'id': 52, 'name': 'motorcycle', 'isthing': 1, 'color': [255, 31, 0]},\n    {'id': 54, 'name': 'ship_or_boat', 'isthing': 1, 'color': [153, 255, 0]},\n    {'id': 55, 'name': 'raft', 'isthing': 1, 'color': [0, 0, 255]},\n    {'id': 56, 'name': 'airplane', 'isthing': 1, 'color': 
[255, 71, 0]},\n    {'id': 60, 'name': 'person', 'isthing': 1, 'color': [11, 200, 200]},\n    {'id': 61, 'name': 'cat', 'isthing': 1, 'color': [255, 82, 0]},\n    {'id': 62, 'name': 'dog', 'isthing': 1, 'color': [0, 255, 245]},\n    {'id': 63, 'name': 'horse', 'isthing': 1, 'color': [0, 61, 255]},\n    {'id': 64, 'name': 'cattle', 'isthing': 1, 'color': [0, 255, 112]},\n    {'id': 65, 'name': 'other_animal', 'isthing': 1, 'color': [0, 255, 133]},\n    {'id': 72, 'name': 'skateboard', 'isthing': 1, 'color': [0, 82, 255]},\n    {'id': 74, 'name': 'ball', 'isthing': 1, 'color': [0, 255, 173]},\n    {'id': 76, 'name': 'box', 'isthing': 1, 'color': [173, 255, 0]},\n    {'id': 77, 'name': 'traveling_case_or_trolley_case', 'isthing': 1, 'color': [0, 255, 153]},\n    {'id': 78, 'name': 'basket', 'isthing': 1, 'color': [255, 92, 0]},\n    {'id': 79, 'name': 'bag_or_package', 'isthing': 1, 'color': [255, 0, 255]},\n    {'id': 82, 'name': 'plate', 'isthing': 1, 'color': [255, 173, 0]},\n    {'id': 83, 'name': 'tub_or_bowl_or_pot', 'isthing': 1, 'color': [255, 0, 20]},\n    {'id': 84, 'name': 'bottle_or_cup', 'isthing': 1, 'color': [255, 184, 184]},\n    {'id': 85, 'name': 'barrel', 'isthing': 1, 'color': [0, 31, 255]},\n    {'id': 86, 'name': 'fishbowl', 'isthing': 1, 'color': [0, 255, 61]},\n    {'id': 87, 'name': 'bed', 'isthing': 1, 'color': [0, 71, 255]},\n    {'id': 88, 'name': 'pillow', 'isthing': 1, 'color': [255, 0, 204]},\n    {'id': 89, 'name': 'table_or_desk', 'isthing': 1, 'color': [0, 255, 194]},\n    {'id': 90, 'name': 'chair_or_seat', 'isthing': 1, 'color': [0, 255, 82]},\n    {'id': 91, 'name': 'bench', 'isthing': 1, 'color': [0, 10, 255]},\n    {'id': 92, 'name': 'sofa', 'isthing': 1, 'color': [0, 112, 255]},\n    {'id': 95, 'name': 'gun', 'isthing': 1, 'color': [0, 122, 255]},\n    {'id': 96, 'name': 'commode', 'isthing': 1, 'color': [0, 255, 163]},\n    {'id': 97, 'name': 'roaster', 'isthing': 1, 'color': [255, 153, 0]},\n    {'id': 99, 'name': 
'refrigerator', 'isthing': 1, 'color': [255, 112, 0]},\n    {'id': 100, 'name': 'washing_machine', 'isthing': 1, 'color': [143, 255, 0]},\n    {'id': 101, 'name': 'Microwave_oven', 'isthing': 1, 'color': [82, 0, 255]},\n    {'id': 102, 'name': 'fan', 'isthing': 1, 'color': [163, 255, 0]},\n    {'id': 106, 'name': 'painting_or_poster', 'isthing': 1, 'color': [0, 255, 92]},\n    {'id': 107, 'name': 'mirror', 'isthing': 1, 'color': [184, 0, 255]},\n    {'id': 108, 'name': 'flower_pot_or_vase', 'isthing': 1, 'color': [255, 0, 31]},\n    {'id': 109, 'name': 'clock', 'isthing': 1, 'color': [0, 184, 255]},\n    {'id': 114, 'name': 'screen_or_television', 'isthing': 1, 'color': [112, 224, 255]},\n    {'id': 115, 'name': 'computer', 'isthing': 1, 'color': [70, 184, 160]},\n    {'id': 116, 'name': 'printer', 'isthing': 1, 'color': [163, 0, 255]},\n    {'id': 117, 'name': 'Mobile_phone', 'isthing': 1, 'color': [153, 0, 255]},\n    {'id': 118, 'name': 'keyboard', 'isthing': 1, 'color': [71, 255, 0]},\n    {'id': 122, 'name': 'instrument', 'isthing': 1, 'color': [0, 255, 235]},\n    {'id': 123, 'name': 'train', 'isthing': 1, 'color': [133, 255, 0]}\n]\n\nCLASSES_STUFF = [\n    {'id': 0, 'name': 'wall', 'isthing': 0, 'color': [120, 120, 120]},\n    {'id': 1, 'name': 'ceiling', 'isthing': 0, 'color': [180, 120, 120]},\n    {'id': 3, 'name': 'stair', 'isthing': 0, 'color': [80, 50, 50]},\n    {'id': 5, 'name': 'escalator', 'isthing': 0, 'color': [120, 120, 80]},\n    {'id': 6, 'name': 'Playground_slide', 'isthing': 0, 'color': [140, 140, 140]},\n    {'id': 7, 'name': 'handrail_or_fence', 'isthing': 0, 'color': [204, 5, 255]},\n    {'id': 9, 'name': 'rail', 'isthing': 0, 'color': [4, 250, 7]},\n    {'id': 11, 'name': 'pillar', 'isthing': 0, 'color': [235, 255, 7]},\n    {'id': 12, 'name': 'pole', 'isthing': 0, 'color': [150, 5, 61]},\n    {'id': 13, 'name': 'floor', 'isthing': 0, 'color': [120, 120, 70]},\n    {'id': 14, 'name': 'ground', 'isthing': 0, 'color': [8, 255, 51]},\n    
{'id': 15, 'name': 'grass', 'isthing': 0, 'color': [255, 6, 82]},\n    {'id': 16, 'name': 'sand', 'isthing': 0, 'color': [143, 255, 140]},\n    {'id': 17, 'name': 'athletic_field', 'isthing': 0, 'color': [204, 255, 4]},\n    {'id': 18, 'name': 'road', 'isthing': 0, 'color': [255, 51, 7]},\n    {'id': 19, 'name': 'path', 'isthing': 0, 'color': [204, 70, 3]},\n    {'id': 20, 'name': 'crosswalk', 'isthing': 0, 'color': [0, 102, 200]},\n    {'id': 21, 'name': 'building', 'isthing': 0, 'color': [61, 230, 250]},\n    {'id': 22, 'name': 'house', 'isthing': 0, 'color': [255, 6, 51]},\n    {'id': 23, 'name': 'bridge', 'isthing': 0, 'color': [11, 102, 255]},\n    {'id': 24, 'name': 'tower', 'isthing': 0, 'color': [255, 7, 71]},\n    {'id': 25, 'name': 'windmill', 'isthing': 0, 'color': [255, 9, 224]},\n    {'id': 26, 'name': 'well_or_well_lid', 'isthing': 0, 'color': [9, 7, 230]},\n    {'id': 27, 'name': 'other_construction', 'isthing': 0, 'color': [220, 220, 220]},\n    {'id': 28, 'name': 'sky', 'isthing': 0, 'color': [255, 9, 92]},\n    {'id': 29, 'name': 'mountain', 'isthing': 0, 'color': [112, 9, 255]},\n    {'id': 30, 'name': 'stone', 'isthing': 0, 'color': [8, 255, 214]},\n    {'id': 31, 'name': 'wood', 'isthing': 0, 'color': [7, 255, 224]},\n    {'id': 32, 'name': 'ice', 'isthing': 0, 'color': [255, 184, 6]},\n    {'id': 33, 'name': 'snowfield', 'isthing': 0, 'color': [10, 255, 71]},\n    {'id': 34, 'name': 'grandstand', 'isthing': 0, 'color': [255, 41, 10]},\n    {'id': 35, 'name': 'sea', 'isthing': 0, 'color': [7, 255, 255]},\n    {'id': 36, 'name': 'river', 'isthing': 0, 'color': [224, 255, 8]},\n    {'id': 37, 'name': 'lake', 'isthing': 0, 'color': [102, 8, 255]},\n    {'id': 38, 'name': 'waterfall', 'isthing': 0, 'color': [255, 61, 6]},\n    {'id': 39, 'name': 'water', 'isthing': 0, 'color': [255, 194, 7]},\n    {'id': 40, 'name': 'billboard_or_Bulletin_Board', 'isthing': 0, 'color': [255, 122, 8]},\n    {'id': 42, 'name': 'pipeline', 'isthing': 0, 'color': [255, 
8, 41]},\n    {'id': 45, 'name': 'cushion_or_carpet', 'isthing': 0, 'color': [235, 12, 255]},\n    {'id': 53, 'name': 'wheeled_machine', 'isthing': 0, 'color': [255, 224, 0]},\n    {'id': 57, 'name': 'tyre', 'isthing': 0, 'color': [0, 235, 255]},\n    {'id': 58, 'name': 'traffic_light', 'isthing': 0, 'color': [0, 173, 255]},\n    {'id': 59, 'name': 'lamp', 'isthing': 0, 'color': [31, 0, 255]},\n    {'id': 66, 'name': 'tree', 'isthing': 0, 'color': [255, 0, 0]},\n    {'id': 67, 'name': 'flower', 'isthing': 0, 'color': [255, 163, 0]},\n    {'id': 68, 'name': 'other_plant', 'isthing': 0, 'color': [255, 102, 0]},\n    {'id': 69, 'name': 'toy', 'isthing': 0, 'color': [194, 255, 0]},\n    {'id': 70, 'name': 'ball_net', 'isthing': 0, 'color': [0, 143, 255]},\n    {'id': 71, 'name': 'backboard', 'isthing': 0, 'color': [51, 255, 0]},\n    {'id': 73, 'name': 'bat', 'isthing': 0, 'color': [0, 255, 41]},\n    {'id': 75, 'name': 'cupboard_or_showcase_or_storage_rack', 'isthing': 0, 'color': [10, 0, 255]},\n    {'id': 80, 'name': 'trash_can', 'isthing': 0, 'color': [255, 0, 245]},\n    {'id': 81, 'name': 'cage', 'isthing': 0, 'color': [255, 0, 102]},\n    {'id': 93, 'name': 'shelf', 'isthing': 0, 'color': [51, 0, 255]},\n    {'id': 94, 'name': 'bathtub', 'isthing': 0, 'color': [0, 194, 255]},\n    {'id': 98, 'name': 'other_machine', 'isthing': 0, 'color': [0, 255, 10]},\n    {'id': 103, 'name': 'curtain', 'isthing': 0, 'color': [255, 235, 0]},\n    {'id': 104, 'name': 'textiles', 'isthing': 0, 'color': [8, 184, 170]},\n    {'id': 105, 'name': 'clothes', 'isthing': 0, 'color': [133, 0, 255]},\n    {'id': 110, 'name': 'book', 'isthing': 0, 'color': [0, 214, 255]},\n    {'id': 111, 'name': 'tool', 'isthing': 0, 'color': [255, 0, 112]},\n    {'id': 112, 'name': 'blackboard', 'isthing': 0, 'color': [92, 255, 0]},\n    {'id': 113, 'name': 'tissue', 'isthing': 0, 'color': [0, 224, 255]},\n    {'id': 119, 'name': 'other_electronic_product', 'isthing': 0, 'color': [255, 0, 163]},\n    
{'id': 120, 'name': 'fruit', 'isthing': 0, 'color': [255, 204, 0]},\n    {'id': 121, 'name': 'food', 'isthing': 0, 'color': [255, 0, 143]}\n]\n\nNO_OBJ = 0\nNO_OBJ_HB = 255\nDIVISOR_PAN = 100\nDIVISOR_NEW = 1000\nNUM_THING = 58\nNUM_STUFF = 66\nTHING_B_STUFF = False\n\n\ndef vip2hb(pan_map):\n    assert not THING_B_STUFF, \"VIPSeg only supports stuff -> thing\"\n    pan_new = - np.ones_like(pan_map)\n    vip2hb_thing = {itm['id'] + 1: idx for idx, itm in enumerate(CLASSES_THING)}\n    vip2hb_stuff = {itm['id'] + 1: idx for idx, itm in enumerate(CLASSES_STUFF)}\n    for idx in np.unique(pan_map):\n        if idx == NO_OBJ or idx == 200:\n            pan_new[pan_map == idx] = NO_OBJ_HB * DIVISOR_NEW\n        elif idx > 128:\n            cls_id = idx // DIVISOR_PAN\n            cls_new_id = vip2hb_thing[cls_id]\n            inst_id = idx % DIVISOR_PAN\n            # since stuff -> thing\n            cls_new_id += NUM_STUFF\n            pan_new[pan_map == idx] = cls_new_id * DIVISOR_NEW + inst_id + 1\n        else:\n            pan_new[pan_map == idx] = vip2hb_stuff[idx] * DIVISOR_NEW\n    assert -1. 
not in np.unique(pan_new)\n    return pan_new\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description='Evaluation of VPQ on VIPSeg')\n    parser.add_argument('result_path')\n    parser.add_argument('--gt-path', default='data/VIPSeg')\n    parser.add_argument('--split', default='val')\n    parser.add_argument(\n        '--depth',\n        action='store_true',\n        help='eval depth')\n    parser.add_argument('--nproc', default=32, type=int, help='number of processes')\n    args = parser.parse_args()\n    return args\n\n\ndef vpq_eval(element):\n    pred_ids, gt_ids = element\n    max_ins = 2 ** 16\n    ign_id = 255\n    offset = 2 ** 30\n    num_cat = NUM_THING + NUM_STUFF + 1\n\n    iou_per_class = np.zeros(num_cat, dtype=np.float64)\n    tp_per_class = np.zeros(num_cat, dtype=np.float64)\n    fn_per_class = np.zeros(num_cat, dtype=np.float64)\n    fp_per_class = np.zeros(num_cat, dtype=np.float64)\n\n    def _ids_to_counts(id_array):\n        ids, counts = np.unique(id_array, return_counts=True)\n        return dict(six.moves.zip(ids, counts))\n\n    pred_areas = _ids_to_counts(pred_ids)\n    gt_areas = _ids_to_counts(gt_ids)\n\n    void_id = ign_id * max_ins\n    ign_ids = {\n        gt_id for gt_id in six.iterkeys(gt_areas)\n        if (gt_id // max_ins) == ign_id\n    }\n\n    int_ids = gt_ids.astype(np.int64) * offset + pred_ids.astype(np.int64)\n    int_areas = _ids_to_counts(int_ids)\n\n    def prediction_void_overlap(pred_id):\n        void_int_id = void_id * offset + pred_id\n        return int_areas.get(void_int_id, 0)\n\n    def prediction_ignored_overlap(pred_id):\n        total_ignored_overlap = 0\n        for _ign_id in ign_ids:\n            int_id = _ign_id * offset + pred_id\n            total_ignored_overlap += int_areas.get(int_id, 0)\n        return total_ignored_overlap\n\n    gt_matched = set()\n    pred_matched = set()\n\n    for int_id, int_area in six.iteritems(int_areas):\n        gt_id = int(int_id // offset)\n        gt_cat = 
int(gt_id // max_ins)\n        pred_id = int(int_id % offset)\n        pred_cat = int(pred_id // max_ins)\n        if gt_cat != pred_cat:\n            continue\n        union = (\n                gt_areas[gt_id] + pred_areas[pred_id] - int_area -\n                prediction_void_overlap(pred_id)\n        )\n        iou = int_area / union\n        if iou > 0.5:\n            tp_per_class[gt_cat] += 1\n            iou_per_class[gt_cat] += iou\n            gt_matched.add(gt_id)\n            pred_matched.add(pred_id)\n\n    for gt_id in six.iterkeys(gt_areas):\n        if gt_id in gt_matched:\n            continue\n        cat_id = gt_id // max_ins\n        if cat_id == ign_id:\n            continue\n        fn_per_class[cat_id] += 1\n\n    for pred_id in six.iterkeys(pred_areas):\n        if pred_id in pred_matched:\n            continue\n        if (prediction_ignored_overlap(pred_id) / pred_areas[pred_id]) > 0.5:\n            continue\n        cat = pred_id // max_ins\n        fp_per_class[cat] += 1\n\n    return iou_per_class, tp_per_class, fn_per_class, fp_per_class\n\n\ndef read_to_eval(element):\n    max_ins = 2 ** 16\n\n    pred_list, gt_list = element\n    pred_cat = [mmcv.imread(image[0], flag='unchanged').astype(np.int32) for image in pred_list]\n    pred_ins = [mmcv.imread(image[1], flag='unchanged').astype(np.int32) for image in pred_list]\n    pred_cat = np.concatenate(pred_cat, axis=1)\n    pred_ins = np.concatenate(pred_ins, axis=1)\n    pred = pred_cat.astype(np.int32) * max_ins + pred_ins.astype(np.int32)\n\n    gt_pan = [mmcv.imread(image, flag='unchanged').astype(np.int64) for image in gt_list]\n    gt_pan = np.concatenate(gt_pan, axis=1)\n    gt_pan = vip2hb(gt_pan)\n\n    gt_cls = gt_pan // DIVISOR_NEW\n    gt_ins = gt_pan % DIVISOR_NEW\n\n    gt = gt_cls * max_ins + gt_ins\n    result = vpq_eval([pred, gt])\n\n    return result\n\n\ndef eval_dvpq(result_dir, gt_dir, split='val', k=1, with_depth=True):\n    if with_depth:\n        raise 
NotImplementedError\n    ann_folders = mmcv.list_from_file(os.path.join(gt_dir, \"{}.txt\".format(split)),\n                                      prefix=os.path.join(gt_dir, 'panomasks') + '/')\n    seq_ids = np.arange(0, len(ann_folders)).tolist()\n\n    iou_per_class_all = []\n    tp_per_class_all = []\n    fn_per_class_all = []\n    fp_per_class_all = []\n\n    for seq_id in seq_ids:\n        gt_names = list(mmcv.scandir(ann_folders[seq_id]))\n        gt_pan_names = sorted(list(filter(lambda x: '.png' in x, gt_names)))\n        if not os.path.exists(os.path.join(result_dir, 'panoptic', str(seq_id))):\n            print(\"No predictions found for seq_id {}; evaluating only the sequences before it.\".format(seq_id))\n            break\n        pred_name_panoptic = list(mmcv.scandir(os.path.join(result_dir, 'panoptic', str(seq_id))))\n        pred_ins_names = sorted(list(filter(lambda x: 'ins' in x, pred_name_panoptic)))\n        pred_cls_names = sorted(list(filter(lambda x: 'cat' in x, pred_name_panoptic)))\n        if len(gt_pan_names) != len(pred_ins_names):\n            print(
\"Frame count mismatch for seq_id {}; evaluating only the sequences before it.\".format(seq_id))\n            break\n        elements = []\n        assert len(pred_ins_names) == len(pred_cls_names)\n        assert len(pred_cls_names) == len(gt_pan_names)\n        len_seq = len(pred_ins_names)\n\n        k = min(k, len_seq)\n\n        for idx in range(len_seq):\n            if idx + k - 1 >= len_seq:\n                break\n            pred = []\n            gt = []\n            for j in range(k):\n                # (category map, instance map) pair, in the order read_to_eval expects\n                pred_cur = (os.path.join(result_dir, 'panoptic', str(seq_id), pred_cls_names[idx + j]),\n                            os.path.join(result_dir, 'panoptic', str(seq_id), pred_ins_names[idx + j]))\n                gt_cur = os.path.join(ann_folders[seq_id], gt_pan_names[idx + j])\n                pred.append(pred_cur)\n                gt.append(gt_cur)\n            elements.append((pred, gt))\n\n        N = mp.cpu_count()\n        with mp.Pool(processes=N) as p:\n            results = p.map(read_to_eval, elements)\n\n        iou_per_class = np.stack([result[0] for result in results])\n        iou_per_class_all.append(iou_per_class)\n        tp_per_class = np.stack([result[1] for result in results])\n        tp_per_class_all.append(tp_per_class)\n        fn_per_class = np.stack([result[2] for result in results])\n        fn_per_class_all.append(fn_per_class)\n        fp_per_class = np.stack([result[3] for result in results])\n        fp_per_class_all.append(fp_per_class)\n\n    epsilon = 1e-10\n    iou_per_class_all = np.concatenate(iou_per_class_all, axis=0).sum(axis=0)[:NUM_THING + NUM_STUFF]\n    tp_per_class_all = np.concatenate(tp_per_class_all, axis=0).sum(axis=0)[:NUM_THING + NUM_STUFF]\n    fn_per_class_all = np.concatenate(fn_per_class_all, axis=0).sum(axis=0)[:NUM_THING + NUM_STUFF]\n    fp_per_class_all = np.concatenate(fp_per_class_all, axis=0).sum(axis=0)[:NUM_THING + NUM_STUFF]\n\n    sq = iou_per_class_all / (tp_per_class_all + epsilon)\n    rq = tp_per_class_all / (tp_per_class_all + 0.5 * 
fn_per_class_all + 0.5 * fp_per_class_all + epsilon)\n    pq = sq * rq\n    spq = pq[:NUM_STUFF]\n    tpq = pq[NUM_STUFF:]\n    print(\n        r'PQ : {:.3f} PQ_thing : {:.3f} PQ_stuff : {:.3f}'.format(\n            pq.mean() * 100,\n            tpq.mean() * 100,\n            spq.mean() * 100)\n    )\n\n\n# usage python eval_dstq_vipseg.py /opt/data/results/test --gt-path /opt/data/VIPSeg\nif __name__ == '__main__':\n    args = parse_args()\n    result_path = args.result_path\n    gt_path = args.gt_path\n    split = args.split\n    for k in [1, 2, 4, 6]:\n        print(\"k={}\".format(k))\n        eval_dvpq(result_path, gt_path, split=split, with_depth=args.depth, k=k)\n"
  },
  {
    "path": "tools/flops_counter.py",
    "content": "'''\nCopyright (C) 2019 Sovrasov V. - All Rights Reserved\n * You may use, distribute and modify this code under the\n * terms of the MIT license.\n * You should have received a copy of the MIT license with\n * this file. If not visit https://opensource.org/licenses/MIT\n'''\n\nimport sys\nfrom functools import partial\n\nimport mmcv.cnn.bricks.transformer\nimport numpy as np\nimport torch\nimport torch.nn as nn\n\nimport mmcv\n\ndef get_model_complexity_info(model, input_res,\n                              print_per_layer_stat=True,\n                              as_strings=True,\n                              input_constructor=None, ost=sys.stdout,\n                              verbose=False, ignore_modules=[],\n                              custom_modules_hooks={}):\n    assert type(input_res) is tuple\n    assert len(input_res) >= 1\n    assert isinstance(model, nn.Module)\n    global CUSTOM_MODULES_MAPPING\n    CUSTOM_MODULES_MAPPING = custom_modules_hooks\n    flops_model = add_flops_counting_methods(model)\n    flops_model.eval()\n    flops_model.start_flops_count(ost=ost, verbose=verbose,\n                                  ignore_list=ignore_modules)\n    if input_constructor:\n        input = input_constructor(input_res)\n        _ = flops_model(**input)\n    else:\n        try:\n            batch = torch.ones(()).new_empty((1, *input_res),\n                                             dtype=next(flops_model.parameters()).dtype,\n                                             device=next(flops_model.parameters()).device)\n        except StopIteration:\n            batch = torch.ones(()).new_empty((1, *input_res))\n\n        _ = flops_model(batch)\n\n    flops_count, params_count = flops_model.compute_average_flops_cost()\n    if print_per_layer_stat:\n        print_model_with_flops(flops_model, flops_count, params_count, ost=ost)\n    flops_model.stop_flops_count()\n    CUSTOM_MODULES_MAPPING = {}\n\n    if as_strings:\n        return 
flops_to_string(flops_count), params_to_string(params_count)\n\n    return flops_count, params_count\n\n\ndef flops_to_string(flops, units='GMac', precision=2):\n    if units is None:\n        if flops // 10**9 > 0:\n            return str(round(flops / 10.**9, precision)) + ' GMac'\n        elif flops // 10**6 > 0:\n            return str(round(flops / 10.**6, precision)) + ' MMac'\n        elif flops // 10**3 > 0:\n            return str(round(flops / 10.**3, precision)) + ' KMac'\n        else:\n            return str(flops) + ' Mac'\n    else:\n        if units == 'GMac':\n            return str(round(flops / 10.**9, precision)) + ' ' + units\n        elif units == 'MMac':\n            return str(round(flops / 10.**6, precision)) + ' ' + units\n        elif units == 'KMac':\n            return str(round(flops / 10.**3, precision)) + ' ' + units\n        else:\n            return str(flops) + ' Mac'\n\n\ndef params_to_string(params_num, units=None, precision=2):\n    if units is None:\n        if params_num // 10 ** 6 > 0:\n            return str(round(params_num / 10 ** 6, 2)) + ' M'\n        elif params_num // 10 ** 3:\n            return str(round(params_num / 10 ** 3, 2)) + ' k'\n        else:\n            return str(params_num)\n    else:\n        if units == 'M':\n            return str(round(params_num / 10.**6, precision)) + ' ' + units\n        elif units == 'K':\n            return str(round(params_num / 10.**3, precision)) + ' ' + units\n        else:\n            return str(params_num)\n\n\ndef accumulate_flops(self):\n    if is_supported_instance(self):\n        return self.__flops__\n    else:\n        sum = 0\n        for m in self.children():\n            sum += m.accumulate_flops()\n        return sum\n\n\ndef print_model_with_flops(model, total_flops, total_params, units='GMac',\n                           precision=3, ost=sys.stdout):\n    if total_flops < 1:\n        total_flops = 1\n\n    def accumulate_params(self):\n        if 
is_supported_instance(self):\n            return self.__params__\n        else:\n            sum = 0\n            for m in self.children():\n                sum += m.accumulate_params()\n            return sum\n\n    def flops_repr(self):\n        accumulated_params_num = self.accumulate_params()\n        accumulated_flops_cost = self.accumulate_flops() / model.__batch_counter__\n        return ', '.join([params_to_string(accumulated_params_num,\n                                           units='M', precision=precision),\n                          '{:.3%} Params'.format(accumulated_params_num / total_params),\n                          flops_to_string(accumulated_flops_cost,\n                                          units=units, precision=precision),\n                          '{:.3%} MACs'.format(accumulated_flops_cost / total_flops),\n                          self.original_extra_repr()])\n\n    def add_extra_repr(m):\n        m.accumulate_flops = accumulate_flops.__get__(m)\n        m.accumulate_params = accumulate_params.__get__(m)\n        flops_extra_repr = flops_repr.__get__(m)\n        if m.extra_repr != flops_extra_repr:\n            m.original_extra_repr = m.extra_repr\n            m.extra_repr = flops_extra_repr\n            assert m.extra_repr != m.original_extra_repr\n\n    def del_extra_repr(m):\n        if hasattr(m, 'original_extra_repr'):\n            m.extra_repr = m.original_extra_repr\n            del m.original_extra_repr\n        if hasattr(m, 'accumulate_flops'):\n            del m.accumulate_flops\n\n    model.apply(add_extra_repr)\n    print(repr(model), file=ost)\n    model.apply(del_extra_repr)\n\n\ndef get_model_parameters_number(model):\n    params_num = sum(p.numel() for p in model.parameters() if p.requires_grad)\n    return params_num\n\n\ndef add_flops_counting_methods(net_main_module):\n    # adding additional methods to the existing module object,\n    # this is done this way so that each function has access to self object\n    
net_main_module.start_flops_count = start_flops_count.__get__(net_main_module)\n    net_main_module.stop_flops_count = stop_flops_count.__get__(net_main_module)\n    net_main_module.reset_flops_count = reset_flops_count.__get__(net_main_module)\n    net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__(\n                                                    net_main_module)\n\n    net_main_module.reset_flops_count()\n\n    return net_main_module\n\n\ndef compute_average_flops_cost(self):\n    \"\"\"\n    A method that will be available after add_flops_counting_methods() is called\n    on a desired net object.\n    Returns current mean flops consumption per image.\n    \"\"\"\n\n    for m in self.modules():\n        m.accumulate_flops = accumulate_flops.__get__(m)\n\n    flops_sum = self.accumulate_flops()\n\n    for m in self.modules():\n        if hasattr(m, 'accumulate_flops'):\n            del m.accumulate_flops\n\n    params_sum = get_model_parameters_number(self)\n    return flops_sum / self.__batch_counter__, params_sum\n\n\ndef start_flops_count(self, **kwargs):\n    \"\"\"\n    A method that will be available after add_flops_counting_methods() is called\n    on a desired net object.\n    Activates the computation of mean flops consumption per image.\n    Call it before you run the network.\n    \"\"\"\n    add_batch_counter_hook_function(self)\n\n    seen_types = set()\n\n    def add_flops_counter_hook_function(module, ost, verbose, ignore_list):\n        if type(module) in ignore_list:\n            seen_types.add(type(module))\n            if is_supported_instance(module):\n                module.__params__ = 0\n        elif is_supported_instance(module):\n            if hasattr(module, '__flops_handle__'):\n                return\n            if type(module) in CUSTOM_MODULES_MAPPING:\n                handle = module.register_forward_hook(\n                                        CUSTOM_MODULES_MAPPING[type(module)])\n            
else:\n                handle = module.register_forward_hook(MODULES_MAPPING[type(module)])\n            module.__flops_handle__ = handle\n            seen_types.add(type(module))\n        else:\n            if verbose and not type(module) in (nn.Sequential, nn.ModuleList) and \\\n               not type(module) in seen_types:\n                print('Warning: module ' + type(module).__name__ +\n                      ' is treated as a zero-op.', file=ost)\n            seen_types.add(type(module))\n\n    self.apply(partial(add_flops_counter_hook_function, **kwargs))\n\n\ndef stop_flops_count(self):\n    \"\"\"\n    A method that will be available after add_flops_counting_methods() is called\n    on a desired net object.\n    Stops computing the mean flops consumption per image.\n    Call whenever you want to pause the computation.\n    \"\"\"\n    remove_batch_counter_hook_function(self)\n    self.apply(remove_flops_counter_hook_function)\n\n\ndef reset_flops_count(self):\n    \"\"\"\n    A method that will be available after add_flops_counting_methods() is called\n    on a desired net object.\n    Resets statistics computed so far.\n    \"\"\"\n    add_batch_counter_variables_or_reset(self)\n    self.apply(add_flops_counter_variable_or_reset)\n\n\n# ---- Internal functions\ndef empty_flops_counter_hook(module, input, output):\n    module.__flops__ += 0\n\n\ndef upsample_flops_counter_hook(module, input, output):\n    output_size = output[0]\n    batch_size = output_size.shape[0]\n    output_elements_count = batch_size\n    for val in output_size.shape[1:]:\n        output_elements_count *= val\n    module.__flops__ += int(output_elements_count)\n\n\ndef relu_flops_counter_hook(module, input, output):\n    active_elements_count = output.numel()\n    module.__flops__ += int(active_elements_count)\n\n\ndef linear_flops_counter_hook(module, input, output):\n    input = input[0]\n    # pytorch checks dimensions, so here we don't care much\n    output_last_dim = 
output.shape[-1]\n    bias_flops = output_last_dim if module.bias is not None else 0\n    module.__flops__ += int(np.prod(input.shape) * output_last_dim + bias_flops)\n\n\ndef pool_flops_counter_hook(module, input, output):\n    input = input[0]\n    module.__flops__ += int(np.prod(input.shape))\n\n\ndef bn_flops_counter_hook(module, input, output):\n    input = input[0]\n\n    batch_flops = np.prod(input.shape)\n    if module.affine:\n        batch_flops *= 2\n    module.__flops__ += int(batch_flops)\n\n\ndef conv_flops_counter_hook(conv_module, input, output):\n    # Can have multiple inputs, getting the first one\n    input = input[0]\n\n    batch_size = input.shape[0]\n    output_dims = list(output.shape[2:])\n\n    kernel_dims = list(conv_module.kernel_size)\n    in_channels = conv_module.in_channels\n    out_channels = conv_module.out_channels\n    groups = conv_module.groups\n\n    filters_per_channel = out_channels // groups\n    conv_per_position_flops = int(np.prod(kernel_dims)) * \\\n        in_channels * filters_per_channel\n\n    active_elements_count = batch_size * int(np.prod(output_dims))\n\n    overall_conv_flops = conv_per_position_flops * active_elements_count\n\n    bias_flops = 0\n\n    if conv_module.bias is not None:\n\n        bias_flops = out_channels * active_elements_count\n\n    overall_flops = overall_conv_flops + bias_flops\n\n    conv_module.__flops__ += int(overall_flops)\n\n\ndef batch_counter_hook(module, input, output):\n    batch_size = 1\n    if len(input) > 0:\n        # Can have multiple inputs, getting the first one\n        input = input[0]\n        batch_size = len(input)\n    else:\n        pass\n        print('Warning! 
No positional inputs found for a module,'\n              ' assuming batch size is 1.')\n    module.__batch_counter__ += batch_size\n\n\ndef rnn_flops(flops, rnn_module, w_ih, w_hh, input_size):\n    # matrix matrix mult ih state and internal state\n    flops += w_ih.shape[0]*w_ih.shape[1]\n    # matrix matrix mult hh state and internal state\n    flops += w_hh.shape[0]*w_hh.shape[1]\n    if isinstance(rnn_module, (nn.RNN, nn.RNNCell)):\n        # add both operations\n        flops += rnn_module.hidden_size\n    elif isinstance(rnn_module, (nn.GRU, nn.GRUCell)):\n        # hadamard of r\n        flops += rnn_module.hidden_size\n        # adding operations from both states\n        flops += rnn_module.hidden_size*3\n        # last two hadamard product and add\n        flops += rnn_module.hidden_size*3\n    elif isinstance(rnn_module, (nn.LSTM, nn.LSTMCell)):\n        # adding operations from both states\n        flops += rnn_module.hidden_size*4\n        # two hadamard product and add for C state\n        flops += rnn_module.hidden_size + rnn_module.hidden_size + rnn_module.hidden_size\n        # final hadamard\n        flops += rnn_module.hidden_size + rnn_module.hidden_size + rnn_module.hidden_size\n    return flops\n\n\ndef rnn_flops_counter_hook(rnn_module, input, output):\n    \"\"\"\n    Takes into account batch goes at first position, contrary\n    to pytorch common rule (but actually it doesn't matter).\n    IF sigmoid and tanh are made hard, only a comparison FLOPS should be accurate\n    \"\"\"\n    flops = 0\n    # input is a tuple containing a sequence to process and (optionally) hidden state\n    inp = input[0]\n    batch_size = inp.shape[0]\n    seq_length = inp.shape[1]\n    num_layers = rnn_module.num_layers\n\n    for i in range(num_layers):\n        w_ih = rnn_module.__getattr__('weight_ih_l' + str(i))\n        w_hh = rnn_module.__getattr__('weight_hh_l' + str(i))\n        if i == 0:\n            input_size = rnn_module.input_size\n        else:\n   
         input_size = rnn_module.hidden_size\n        flops = rnn_flops(flops, rnn_module, w_ih, w_hh, input_size)\n        if rnn_module.bias:\n            b_ih = rnn_module.__getattr__('bias_ih_l' + str(i))\n            b_hh = rnn_module.__getattr__('bias_hh_l' + str(i))\n            flops += b_ih.shape[0] + b_hh.shape[0]\n\n    flops *= batch_size\n    flops *= seq_length\n    if rnn_module.bidirectional:\n        flops *= 2\n    rnn_module.__flops__ += int(flops)\n\n\ndef rnn_cell_flops_counter_hook(rnn_cell_module, input, output):\n    flops = 0\n    inp = input[0]\n    batch_size = inp.shape[0]\n    w_ih = rnn_cell_module.__getattr__('weight_ih')\n    w_hh = rnn_cell_module.__getattr__('weight_hh')\n    input_size = inp.shape[1]\n    flops = rnn_flops(flops, rnn_cell_module, w_ih, w_hh, input_size)\n    if rnn_cell_module.bias:\n        b_ih = rnn_cell_module.__getattr__('bias_ih')\n        b_hh = rnn_cell_module.__getattr__('bias_hh')\n        flops += b_ih.shape[0] + b_hh.shape[0]\n\n    flops *= batch_size\n    rnn_cell_module.__flops__ += int(flops)\n\ndef ffn_hook(module, input, output):\n    input = input[0]\n    for layer in module.layers:\n        if isinstance(layer, nn.Sequential):\n            layer_cur = layer[0]\n        else:\n            layer_cur = layer\n        if not isinstance(layer_cur, nn.Linear):\n            continue\n        # pytorch checks dimensions, so here we don't care much\n        output_last_dim = layer_cur.out_features\n        bias_flops = output_last_dim if layer_cur.bias is not None else 0\n        module.__flops__ += int(input.shape[0] * input.shape[1] * layer_cur.in_features) * output_last_dim + bias_flops\n\ndef multihead_attention_counter_hook(multihead_attention_module, input, output):\n    flops = 0\n    if len(input) == 0:\n        # Forward hooks only receive positional inputs. When the module is\n        # called with keyword arguments (e.g. query=..., key=..., value=...),\n        # approximate q, k and v with the attention output.\n        q, k, v = output[0], output[0], output[0]\n    else:\n        
q, k, v = input\n    # shapes are assumed to be (seq_len, batch, embed_dim), i.e. batch_first=False\n    batch_size = q.shape[1]\n\n    num_heads = multihead_attention_module.num_heads\n    # nn.MultiheadAttention exposes the model width as embed_dim\n    embed_dims = multihead_attention_module.embed_dim\n    kdim = multihead_attention_module.kdim\n    vdim = multihead_attention_module.vdim\n    if kdim is None:\n        kdim = embed_dims\n    if vdim is None:\n        vdim = embed_dims\n\n    # initial projections\n    flops = q.shape[0] * q.shape[2] * embed_dims + \\\n        k.shape[0] * k.shape[2] * kdim + \\\n        v.shape[0] * v.shape[2] * vdim\n    if multihead_attention_module.in_proj_bias is not None:\n        flops += (q.shape[0] + k.shape[0] + v.shape[0]) * embed_dims\n\n    # attention heads: scale, matmul, softmax, matmul\n    head_dim = embed_dims // num_heads\n    head_flops = q.shape[0] * head_dim + \\\n        head_dim * q.shape[0] * k.shape[0] + \\\n        q.shape[0] * k.shape[0] + \\\n        q.shape[0] * k.shape[0] * head_dim\n\n    flops += num_heads * head_flops\n\n    # final projection, bias is always enabled\n    flops += q.shape[0] * embed_dims * (embed_dims + 1)\n\n    flops *= batch_size\n    multihead_attention_module.__flops__ += int(flops)\n\n\ndef add_batch_counter_variables_or_reset(module):\n\n    module.__batch_counter__ = 0\n\n\ndef add_batch_counter_hook_function(module):\n    if hasattr(module, '__batch_counter_handle__'):\n        return\n\n    handle = module.register_forward_hook(batch_counter_hook)\n    module.__batch_counter_handle__ = handle\n\n\ndef remove_batch_counter_hook_function(module):\n    if hasattr(module, '__batch_counter_handle__'):\n        module.__batch_counter_handle__.remove()\n        del module.__batch_counter_handle__\n\n\ndef add_flops_counter_variable_or_reset(module):\n    if is_supported_instance(module):\n        if hasattr(module, '__flops__') or hasattr(module, '__params__'):\n            print('Warning: variables __flops__ or __params__ are already '\n                  'defined for the module ' + 
type(module).__name__ +\n                  ' ptflops can affect your code!')\n        module.__flops__ = 0\n        module.__params__ = get_model_parameters_number(module)\n\n\nCUSTOM_MODULES_MAPPING = {}\n\ndef norm_flops_counter_hook(module, input, output):\n    input = input[0]\n\n    batch_flops = np.prod(input.shape)\n    if (getattr(module, 'affine', False)\n            or getattr(module, 'elementwise_affine', False)):\n        batch_flops *= 2\n    module.__flops__ += int(batch_flops)\n\nMODULES_MAPPING = {\n    # convolutions\n    nn.Conv1d: conv_flops_counter_hook,\n    nn.Conv2d: conv_flops_counter_hook,\n    nn.Conv3d: conv_flops_counter_hook,\n    # activations\n    nn.ReLU: relu_flops_counter_hook,\n    nn.PReLU: relu_flops_counter_hook,\n    nn.ELU: relu_flops_counter_hook,\n    nn.LeakyReLU: relu_flops_counter_hook,\n    nn.ReLU6: relu_flops_counter_hook,\n    # poolings\n    nn.MaxPool1d: pool_flops_counter_hook,\n    nn.AvgPool1d: pool_flops_counter_hook,\n    nn.AvgPool2d: pool_flops_counter_hook,\n    nn.MaxPool2d: pool_flops_counter_hook,\n    nn.MaxPool3d: pool_flops_counter_hook,\n    nn.AvgPool3d: pool_flops_counter_hook,\n    nn.AdaptiveMaxPool1d: pool_flops_counter_hook,\n    nn.AdaptiveAvgPool1d: pool_flops_counter_hook,\n    nn.AdaptiveMaxPool2d: pool_flops_counter_hook,\n    nn.AdaptiveAvgPool2d: pool_flops_counter_hook,\n    nn.AdaptiveMaxPool3d: pool_flops_counter_hook,\n    nn.AdaptiveAvgPool3d: pool_flops_counter_hook,\n    # BNs\n    nn.BatchNorm1d: bn_flops_counter_hook,\n    nn.BatchNorm2d: bn_flops_counter_hook,\n    nn.BatchNorm3d: bn_flops_counter_hook,\n\n    nn.InstanceNorm1d: bn_flops_counter_hook,\n    nn.InstanceNorm2d: bn_flops_counter_hook,\n    nn.InstanceNorm3d: bn_flops_counter_hook,\n    nn.GroupNorm: bn_flops_counter_hook,\n\n    # normalizations\n    # nn.BatchNorm1d: norm_flops_counter_hook,\n    # nn.BatchNorm2d: norm_flops_counter_hook,\n    # nn.BatchNorm3d: norm_flops_counter_hook,\n    # nn.GroupNorm: 
norm_flops_counter_hook,\n    # nn.InstanceNorm1d: norm_flops_counter_hook,\n    # nn.InstanceNorm2d: norm_flops_counter_hook,\n    # nn.InstanceNorm3d: norm_flops_counter_hook,\n    nn.LayerNorm: norm_flops_counter_hook,\n\n    # FC\n    nn.Linear: linear_flops_counter_hook,\n    # Upscale\n    nn.Upsample: upsample_flops_counter_hook,\n    # Deconvolution\n    nn.ConvTranspose1d: conv_flops_counter_hook,\n    nn.ConvTranspose2d: conv_flops_counter_hook,\n    nn.ConvTranspose3d: conv_flops_counter_hook,\n    # RNN\n    nn.RNN: rnn_flops_counter_hook,\n    nn.GRU: rnn_flops_counter_hook,\n    nn.LSTM: rnn_flops_counter_hook,\n    nn.RNNCell: rnn_cell_flops_counter_hook,\n    nn.LSTMCell: rnn_cell_flops_counter_hook,\n    nn.GRUCell: rnn_cell_flops_counter_hook,\n    nn.MultiheadAttention: multihead_attention_counter_hook,\n\n    mmcv.cnn.bricks.transformer.FFN:ffn_hook\n}\n\n\n\ndef is_supported_instance(module):\n    if type(module) in MODULES_MAPPING or type(module) in CUSTOM_MODULES_MAPPING:\n        return True\n    return False\n\n\ndef remove_flops_counter_hook_function(module):\n    if is_supported_instance(module):\n        if hasattr(module, '__flops_handle__'):\n            module.__flops_handle__.remove()\n            del module.__flops_handle__\n"
  },
  {
    "path": "tools/get_flops.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport argparse\n\nimport numpy as np\nimport torch\nfrom mmcv import Config, DictAction\n\nfrom mmdet.models import build_detector\n\ntry:\n    from mmcv.cnn import get_model_complexity_info\n    # from tools.flops_counter import get_model_complexity_info\nexcept ImportError:\n    raise ImportError('Please upgrade mmcv to >0.6.2')\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(\n        description='Compute the FLOPs and parameters of a detector')\n    parser.add_argument('config', help='config file path')\n    parser.add_argument(\n        '--shape',\n        type=int,\n        nargs='+',\n        default=[1280, 800],\n        help='input image size')\n    parser.add_argument(\n        '--cfg-options',\n        nargs='+',\n        action=DictAction,\n        help='override some settings in the used config, the key-value pair '\n        'in xxx=yyy format will be merged into config file. If the value to '\n        'be overwritten is a list, it should be like key=\"[a,b]\" or key=a,b. '\n        'It also allows nested list/tuple values, e.g. 
key=\"[(a,b),(c,d)]\". '\n        'Note that the quotation marks are necessary and that no white space '\n        'is allowed.')\n    parser.add_argument(\n        '--size-divisor',\n        type=int,\n        default=32,\n        help='Pad the input image, the minimum size that is divisible '\n        'by size_divisor, -1 means do not pad the image.')\n    args = parser.parse_args()\n    return args\n\n\ndef main():\n\n    args = parse_args()\n\n    if len(args.shape) == 1:\n        h = w = args.shape[0]\n    elif len(args.shape) == 2:\n        h, w = args.shape\n    else:\n        raise ValueError('invalid input shape')\n    orig_shape = (3, h, w)\n    divisor = args.size_divisor\n    if divisor > 0:\n        h = int(np.ceil(h / divisor)) * divisor\n        w = int(np.ceil(w / divisor)) * divisor\n\n    input_shape = (3, h, w)\n\n    cfg = Config.fromfile(args.config)\n    if args.cfg_options is not None:\n        cfg.merge_from_dict(args.cfg_options)\n    # import modules from string list.\n    if cfg.get('custom_imports', None):\n        from mmcv.utils import import_modules_from_strings\n        import_modules_from_strings(**cfg['custom_imports'])\n\n    model = build_detector(\n        cfg.model,\n        train_cfg=cfg.get('train_cfg'),\n        test_cfg=cfg.get('test_cfg'))\n    if torch.cuda.is_available():\n        model.cuda()\n    model.eval()\n\n    if hasattr(model, 'forward_dummy'):\n        model.forward = model.forward_dummy\n    else:\n        raise NotImplementedError(\n            'FLOPs counter is currently not supported with {}'.\n            format(model.__class__.__name__))\n\n    flops, params = get_model_complexity_info(model, input_shape)\n    split_line = '=' * 30\n\n    if divisor > 0 and \\\n            input_shape != orig_shape:\n        print(f'{split_line}\\nUse size divisor to set input shape '\n              f'from {orig_shape} to {input_shape}\\n')\n    print(f'{split_line}\\nInput shape: {input_shape}\\n'\n          
f'Flops: {flops}\\nParams: {params}\\n{split_line}')\n    print('!!!Please be cautious if you use the results in papers. '\n          'You may need to check if all ops are supported and verify that the '\n          'flops computation is correct.')\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "tools/inference_kitti_step.sh",
    "content": "#!/usr/bin/env bash\n\nCONFIG=$1\nCHECKPOINT=$2\nLOG=$3\n\n# configs/det/video_knet_kitti_step/video_knet_s3_r50_rpn_1x_kitti_step_sigmoid_stride2.py logger/models/video_knet_vis/video_knet_step_quansi_r50.pth logger/results/kitti_step_merge_joint_semantic_filter\n# configs/det/video_knet_kitti_step/video_knet_s3_r50_rpn_1x_kitti_step_sigmoid_stride2.py logger/models/video_knet_vis/video_knet_step_quansi_r50.pth logger/results/kitti_step_semantic_filter\n\n# --cfg-options data.test.split=val model.roi_head.merge_joint=True model.semantic_filter=True\n# --cfg-options data.test.split=val model.roi_head.merge_joint=False model.semantic_filter=True\n\nPYTHONPATH=\"$(dirname $0)/..\":$PYTHONPATH \\\npython $(dirname \"$0\")/test_dvps.py $CONFIG $CHECKPOINT --eval dummy --show-dir $LOG ${@:4}\n\nPYTHONPATH=\"$(dirname $0)/..\":$PYTHONPATH \\\npython $(dirname \"$0\")/eval_dstq_step.py $LOG\n"
  },
  {
    "path": "tools/slurm_test.sh",
    "content": "#!/usr/bin/env bash\n\nset -x\n\nPARTITION=$1\nJOB_NAME=$2\nCONFIG=$3\nCHECKPOINT=$4\nGPUS=${GPUS:-8}\nGPUS_PER_NODE=${GPUS_PER_NODE:-8}\nCPUS_PER_TASK=${CPUS_PER_TASK:-5}\nPY_ARGS=${@:5}\nSRUN_ARGS=${SRUN_ARGS:-\"\"}\n\nPYTHONPATH=\"$(dirname $0)/..\":$PYTHONPATH \\\nsrun -p ${PARTITION} \\\n    --job-name=${JOB_NAME} \\\n    --gres=gpu:${GPUS_PER_NODE} \\\n    --ntasks=${GPUS} \\\n    --ntasks-per-node=${GPUS_PER_NODE} \\\n    --cpus-per-task=${CPUS_PER_TASK} \\\n    --kill-on-bad-exit=1 \\\n    ${SRUN_ARGS} \\\n    python -u tools/test.py ${CONFIG} ${CHECKPOINT} --launcher=\"slurm\" ${PY_ARGS}\n"
  },
  {
    "path": "tools/slurm_test_dvps.sh",
    "content": "#!/usr/bin/env bash\n\nset -x\n\nPARTITION=$1\nJOB_NAME=$2\nCONFIG=$3\nCHECKPOINT=$4\nGPUS=${GPUS:-1}\nGPUS_PER_NODE=${GPUS_PER_NODE:-1}\nCPUS_PER_TASK=${CPUS_PER_TASK:-5}\nPY_ARGS=${@:5}\nSRUN_ARGS=${SRUN_ARGS:-\"\"}\n\nPYTHONPATH=\"$(dirname $0)/..\":$PYTHONPATH \\\nsrun -p ${PARTITION} \\\n    --job-name=${JOB_NAME} \\\n    --gres=gpu:${GPUS_PER_NODE} \\\n    --ntasks=${GPUS} \\\n    --ntasks-per-node=${GPUS_PER_NODE} \\\n    --cpus-per-task=${CPUS_PER_TASK} \\\n    --kill-on-bad-exit=1 \\\n    ${SRUN_ARGS} \\\n    python -u tools/test_dvps.py ${CONFIG} ${CHECKPOINT} --launcher=\"slurm\" ${PY_ARGS}"
  },
  {
    "path": "tools/slurm_test_step.sh",
    "content": "#!/usr/bin/env bash\n\nset -x\n\nPARTITION=$1\nJOB_NAME=$2\nCONFIG=$3\nCHECKPOINT=$4\nGPUS=${GPUS:-1}\nGPUS_PER_NODE=${GPUS_PER_NODE:-1}\nCPUS_PER_TASK=${CPUS_PER_TASK:-5}\nPY_ARGS=${@:5}\nSRUN_ARGS=${SRUN_ARGS:-\"\"}\n\nPYTHONPATH=\"$(dirname $0)/..\":$PYTHONPATH \\\nsrun -p ${PARTITION} \\\n    --job-name=${JOB_NAME} \\\n    --gres=gpu:${GPUS_PER_NODE} \\\n    --ntasks=${GPUS} \\\n    --ntasks-per-node=${GPUS_PER_NODE} \\\n    --cpus-per-task=${CPUS_PER_TASK} \\\n    --kill-on-bad-exit=1 \\\n    ${SRUN_ARGS} \\\n    python -u tools/test_step.py ${CONFIG} ${CHECKPOINT} --launcher=\"slurm\" ${PY_ARGS}"
  },
  {
    "path": "tools/slurm_test_vis.sh",
    "content": "#!/usr/bin/env bash\n\nset -x\n\nPARTITION=$1\nJOB_NAME=$2\nCONFIG=$3\nCHECKPOINT=$4\nGPUS=${GPUS:-1}\nGPUS_PER_NODE=${GPUS_PER_NODE:-1}\nCPUS_PER_TASK=${CPUS_PER_TASK:-5}\nPY_ARGS=${@:5}\nSRUN_ARGS=${SRUN_ARGS:-\"\"}\n\nPYTHONPATH=\"$(dirname $0)/..\":$PYTHONPATH \\\nsrun -p ${PARTITION} \\\n    --job-name=${JOB_NAME} \\\n    --gres=gpu:${GPUS_PER_NODE} \\\n    --ntasks=${GPUS} \\\n    --ntasks-per-node=${GPUS_PER_NODE} \\\n    --cpus-per-task=${CPUS_PER_TASK} \\\n    --kill-on-bad-exit=1 \\\n    ${SRUN_ARGS} \\\n    python -u tools/test_vis.py ${CONFIG} ${CHECKPOINT} --launcher=\"slurm\" ${PY_ARGS}"
  },
  {
    "path": "tools/slurm_test_vps.sh",
    "content": "#!/usr/bin/env bash\n\nset -x\n\nPARTITION=$1\nJOB_NAME=$2\nCONFIG=$3\nCHECKPOINT=$4\nGPUS=${GPUS:-1}\nGPUS_PER_NODE=${GPUS_PER_NODE:-8}\nCPUS_PER_TASK=${CPUS_PER_TASK:-5}\nPY_ARGS=${@:5}\nSRUN_ARGS=${SRUN_ARGS:-\"\"}\n\nPYTHONPATH=\"$(dirname $0)/..\":$PYTHONPATH \\\nsrun -p ${PARTITION} \\\n    --job-name=${JOB_NAME} \\\n    --gres=gpu:${GPUS_PER_NODE} \\\n    --ntasks=${GPUS} \\\n    --ntasks-per-node=${GPUS_PER_NODE} \\\n    --cpus-per-task=${CPUS_PER_TASK} \\\n    --kill-on-bad-exit=1 \\\n    ${SRUN_ARGS} \\\n    python -u tools/test_vps_two_frames.py ${CONFIG} ${CHECKPOINT} --launcher=\"slurm\" ${PY_ARGS}"
  },
  {
    "path": "tools/slurm_train.sh",
    "content": "#!/usr/bin/env bash\n\nset -x\n\nPARTITION=$1\nJOB_NAME=$2\nCONFIG=$3\nWORK_DIR=$4\nGPUS=${GPUS:-8}\nGPUS_PER_NODE=${GPUS_PER_NODE:-8}\nCPUS_PER_TASK=${CPUS_PER_TASK:-5}\nSRUN_ARGS=${SRUN_ARGS:-\"\"}\nPY_ARGS=${@:5}\n\nPYTHONPATH=\"$(dirname $0)/..\":$PYTHONPATH \\\nsrun -p ${PARTITION} \\\n    --job-name=${JOB_NAME} \\\n    --gres=gpu:${GPUS_PER_NODE} \\\n    --ntasks=${GPUS} \\\n    --ntasks-per-node=${GPUS_PER_NODE} \\\n    --cpus-per-task=${CPUS_PER_TASK} \\\n    --kill-on-bad-exit=1 \\\n    ${SRUN_ARGS} \\\n    python -u tools/train.py ${CONFIG} --work-dir=${WORK_DIR} --launcher=\"slurm\" ${PY_ARGS}\n"
  },
  {
    "path": "tools/test.py",
    "content": "import argparse\nimport os\nimport warnings\n\nimport mmcv\nimport torch\nfrom mmcv import Config, DictAction\nfrom mmcv.cnn import fuse_conv_bn\nfrom mmcv.parallel import MMDataParallel, MMDistributedDataParallel\nfrom mmcv.runner import (get_dist_info, init_dist, load_checkpoint,\n                         wrap_fp16_model)\nfrom mmdet.datasets import (build_dataloader, build_dataset,\n                            replace_ImageToTensor)\nfrom mmdet.models import build_detector\n\nfrom external.test import multi_gpu_test, single_gpu_test\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(\n        description='MMDet test (and eval) a model')\n    parser.add_argument('config', help='test config file path')\n    parser.add_argument('checkpoint', help='checkpoint file')\n    parser.add_argument('--out', help='output result file in pickle format')\n    parser.add_argument(\n        '--fuse-conv-bn',\n        action='store_true',\n        help='Whether to fuse conv and bn, this will slightly increase '\n        'the inference speed')\n    parser.add_argument(\n        '--format-only',\n        action='store_true',\n        help='Format the output results without performing evaluation. 
It is '\n        'useful when you want to format the result to a specific format and '\n        'submit it to the test server')\n    parser.add_argument(\n        '--eval',\n        type=str,\n        nargs='+',\n        help='evaluation metrics, which depend on the dataset, e.g., \"bbox\",'\n        ' \"segm\", \"proposal\" for COCO, and \"mAP\", \"recall\" for PASCAL VOC')\n    parser.add_argument('--show', action='store_true', help='show results')\n    parser.add_argument(\n        '--show-dir', help='directory where painted images will be saved')\n    parser.add_argument(\n        '--show-score-thr',\n        type=float,\n        default=0.3,\n        help='score threshold (default: 0.3)')\n    parser.add_argument(\n        '--gpu-collect',\n        action='store_true',\n        help='whether to use gpu to collect results.')\n    parser.add_argument(\n        '--tmpdir',\n        help='tmp directory used for collecting results from multiple '\n        'workers, available when gpu-collect is not specified')\n    parser.add_argument(\n        '--cfg-options',\n        nargs='+',\n        action=DictAction,\n        help='override some settings in the used config, the key-value pair '\n        'in xxx=yyy format will be merged into config file. If the value to '\n        'be overwritten is a list, it should be like key=\"[a,b]\" or key=a,b. '\n        'It also allows nested list/tuple values, e.g. 
key=\"[(a,b),(c,d)]\". '\n        'Note that the quotation marks are necessary and that no white space '\n        'is allowed.')\n    parser.add_argument(\n        '--options',\n        nargs='+',\n        action=DictAction,\n        help='custom options for evaluation, the key-value pair in xxx=yyy '\n        'format will be kwargs for dataset.evaluate() function (deprecated), '\n        'change to --eval-options instead.')\n    parser.add_argument(\n        '--eval-options',\n        nargs='+',\n        action=DictAction,\n        help='custom options for evaluation, the key-value pair in xxx=yyy '\n        'format will be kwargs for dataset.evaluate() function')\n    parser.add_argument(\n        '--launcher',\n        choices=['none', 'pytorch', 'slurm', 'mpi'],\n        default='none',\n        help='job launcher')\n    parser.add_argument('--local_rank', type=int, default=0)\n    args = parser.parse_args()\n    if 'LOCAL_RANK' not in os.environ:\n        os.environ['LOCAL_RANK'] = str(args.local_rank)\n\n    if args.options and args.eval_options:\n        raise ValueError(\n            '--options and --eval-options cannot both be '\n            'specified, --options is deprecated in favor of --eval-options')\n    if args.options:\n        warnings.warn('--options is deprecated in favor of --eval-options')\n        args.eval_options = args.options\n    return args\n\n\ndef main():\n    args = parse_args()\n\n    assert args.out or args.eval or args.format_only or args.show \\\n        or args.show_dir, \\\n        ('Please specify at least one operation (save/eval/format/show the '\n         'results / save the results) with the argument \"--out\", \"--eval\"'\n         ', \"--format-only\", \"--show\" or \"--show-dir\"')\n\n    if args.eval and args.format_only:\n        raise ValueError('--eval and --format-only cannot both be specified')\n\n    if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):\n        raise ValueError('The output file 
must be a pkl file.')\n\n    cfg = Config.fromfile(args.config)\n    if args.cfg_options is not None:\n        cfg.merge_from_dict(args.cfg_options)\n    # import modules from string list.\n    if cfg.get('custom_imports', None):\n        from mmcv.utils import import_modules_from_strings\n        import_modules_from_strings(**cfg['custom_imports'])\n    # set cudnn_benchmark\n    if cfg.get('cudnn_benchmark', False):\n        torch.backends.cudnn.benchmark = True\n    # cfg.model.pretrained = None\n    if cfg.model.get('neck'):\n        if isinstance(cfg.model.neck, list):\n            for neck_cfg in cfg.model.neck:\n                if neck_cfg.get('rfp_backbone'):\n                    if neck_cfg.rfp_backbone.get('pretrained'):\n                        neck_cfg.rfp_backbone.pretrained = None\n        elif cfg.model.neck.get('rfp_backbone'):\n            if cfg.model.neck.rfp_backbone.get('pretrained'):\n                cfg.model.neck.rfp_backbone.pretrained = None\n\n    # in case the test dataset is concatenated\n    samples_per_gpu = 1\n    if isinstance(cfg.data.test, dict):\n        cfg.data.test.test_mode = True\n        samples_per_gpu = cfg.data.test.pop('samples_per_gpu', 1)\n        if samples_per_gpu > 1:\n            # Replace 'ImageToTensor' to 'DefaultFormatBundle'\n            cfg.data.test.pipeline = replace_ImageToTensor(\n                cfg.data.test.pipeline)\n    elif isinstance(cfg.data.test, list):\n        for ds_cfg in cfg.data.test:\n            ds_cfg.test_mode = True\n        samples_per_gpu = max(\n            [ds_cfg.pop('samples_per_gpu', 1) for ds_cfg in cfg.data.test])\n        if samples_per_gpu > 1:\n            for ds_cfg in cfg.data.test:\n                ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline)\n\n    # init distributed env first, since logger depends on the dist info.\n    if args.launcher == 'none':\n        distributed = False\n    else:\n        distributed = True\n        init_dist(args.launcher, 
**cfg.dist_params)\n\n    # build the dataloader\n    dataset = build_dataset(cfg.data.test)\n    data_loader = build_dataloader(\n        dataset,\n        samples_per_gpu=samples_per_gpu,\n        workers_per_gpu=cfg.data.workers_per_gpu,\n        dist=distributed,\n        shuffle=False)\n\n    # build the model and load checkpoint\n    cfg.model.train_cfg = None\n    model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))\n    fp16_cfg = cfg.get('fp16', None)\n    if fp16_cfg is not None:\n        wrap_fp16_model(model)\n    checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')\n    if args.fuse_conv_bn:\n        model = fuse_conv_bn(model)\n    # old versions did not save class info in checkpoints, this workaround is\n    # for backward compatibility\n    if 'CLASSES' in checkpoint.get('meta', {}):\n        model.CLASSES = checkpoint['meta']['CLASSES']\n    else:\n        model.CLASSES = dataset.CLASSES\n\n    if not distributed:\n        model = MMDataParallel(model, device_ids=[0])\n        outputs = single_gpu_test(model, data_loader, args.show, args.show_dir,\n                                  args.show_score_thr)\n    else:\n        model = MMDistributedDataParallel(\n            model.cuda(),\n            device_ids=[torch.cuda.current_device()],\n            broadcast_buffers=False)\n        outputs = multi_gpu_test(model, data_loader, args.tmpdir,\n                                 args.gpu_collect)\n\n    rank, _ = get_dist_info()\n    if rank == 0:\n        if args.out:\n            print(f'\\nwriting results to {args.out}')\n            mmcv.dump(outputs, args.out)\n        kwargs = {} if args.eval_options is None else args.eval_options\n        if args.format_only:\n            dataset.format_results(outputs, **kwargs)\n        if args.eval:\n            eval_kwargs = cfg.get('evaluation', {}).copy()\n            # hard-coded way to remove EvalHook args\n            for key in [\n                    'interval', 'tmpdir', 
'start', 'gpu_collect', 'save_best',\n                    'rule', 'by_epoch'\n            ]:\n                eval_kwargs.pop(key, None)\n            eval_kwargs.update(dict(metric=args.eval, **kwargs))\n            print(dataset.evaluate(outputs, **eval_kwargs))\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "tools/test_dvps.py",
    "content": "import argparse\nimport os\nimport os.path as osp\nimport warnings\nimport numpy as np\nimport pickle\nimport json\nimport mmcv\nimport torch\nfrom mmcv import Config, DictAction\nfrom mmcv.cnn import fuse_conv_bn\nfrom mmcv.parallel import MMDataParallel, MMDistributedDataParallel\nfrom mmcv.runner import (get_dist_info, init_dist, load_checkpoint,\n                         wrap_fp16_model)\nfrom mmdet.datasets import (build_dataloader, build_dataset,\n                            replace_ImageToTensor)\nfrom mmdet.models import build_detector\n\nfrom external.test import encode_mask_results, tensor2imgs\n\n\ndef single_gpu_test(model,\n                    data_loader,\n                    show=False,\n                    out_dir=None,\n                    show_score_thr=0.3,\n                    with_semantic_input=False,\n                    rescale_depth=False,\n                    with_seq=False,\n                    ):\n    if out_dir is None:\n        out_dir = './out'\n    print(\"The output dir is {}\".format(out_dir))\n    model.eval()\n    results = []\n    dataset = data_loader.dataset\n    prog_bar = mmcv.ProgressBar(len(dataset))\n\n    pano_seg_2ch_list = []\n\n    for i, data in enumerate(data_loader):\n        seq_id = data['seq_id'][0].item()\n        img_id = data['img_id'][0].item()\n        data.pop('seq_id')\n        if with_semantic_input:\n            semantic_input = mmcv.imread(\n                os.path.join('data/kitti-dvps/semantic/',\n                             \"{:06d}_{:06d}_semantic.png\".format(seq_id, img_id)), flag='unchanged')\n            semantic_input = torch.tensor(semantic_input, device=data['img'][0].device)\n        else:\n            semantic_input = None\n\n        with torch.no_grad():\n            segm_results = model(return_loss=False, rescale=True, semantic_input=semantic_input, **data)\n\n        sseg_results, track_maps, depth_final, vis_sem, vis_tracker = segm_results\n        batch_size = 1\n\n   
     # dump results\n        seq_folder = str(seq_id) if with_seq else \"\"\n        cat_path = os.path.join(out_dir, 'panoptic', seq_folder, '{:06d}_{:06d}_cat.png'.format(seq_id, img_id))\n        ins_path = os.path.join(out_dir, 'panoptic', seq_folder, '{:06d}_{:06d}_ins.png'.format(seq_id, img_id))\n        dep_path = os.path.join(out_dir, 'depth', seq_folder, '{:06d}_{:06d}.png'.format(seq_id, img_id))\n        vis_path = os.path.join(out_dir, 'vis', seq_folder, '{:06d}_{:06d}.png'.format(seq_id, img_id))\n        depth_final_rescale = mmcv.imresize(depth_final, (300, 100), interpolation='bilinear') \\\n            if depth_final is not None else None\n        mmcv.imwrite(sseg_results.astype(np.uint16), cat_path)\n        mmcv.imwrite(track_maps.astype(np.uint16), ins_path)\n        if depth_final_rescale is not None:\n            mmcv.imwrite(((depth_final_rescale if rescale_depth else depth_final) * 256.).astype(np.uint16), dep_path)\n        mmcv.imwrite(np.concatenate((vis_sem, vis_tracker), axis=0), vis_path)\n\n        for _ in range(batch_size):\n            prog_bar.update()\n\n    return results, pano_seg_2ch_list\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(\n        description='MMDet test (and eval) a model')\n    parser.add_argument('config', help='test config file path')\n    parser.add_argument('checkpoint', help='checkpoint file')\n    parser.add_argument('--out', help='output result file in pickle format')\n    parser.add_argument(\n        '--fuse-conv-bn',\n        action='store_true',\n        help='Whether to fuse conv and bn, this will slightly increase '\n             'the inference speed')\n    parser.add_argument(\n        '--format-only',\n        action='store_true',\n        help='Format the output results without performing evaluation. 
It is '\n             'useful when you want to format the result to a specific format and '\n             'submit it to the test server')\n    parser.add_argument(\n        '--eval',\n        type=str,\n        nargs='+',\n        help='evaluation metrics, which depend on the dataset, e.g., \"bbox\",'\n             ' \"segm\", \"proposal\" for COCO, and \"mAP\", \"recall\" for PASCAL VOC')\n    parser.add_argument('--show', action='store_true', help='show results')\n    parser.add_argument(\n        '--show-dir', help='directory where painted images will be saved')\n    parser.add_argument(\n        '--show-score-thr',\n        type=float,\n        default=0.3,\n        help='score threshold (default: 0.3)')\n    parser.add_argument(\n        '--gpu-collect',\n        action='store_true',\n        help='whether to use gpu to collect results.')\n    parser.add_argument(\n        '--tmpdir',\n        help='tmp directory used for collecting results from multiple '\n             'workers, available when gpu-collect is not specified')\n    parser.add_argument(\n        '--cfg-options',\n        nargs='+',\n        action=DictAction,\n        help='override some settings in the used config, the key-value pair '\n             'in xxx=yyy format will be merged into config file. If the value to '\n             'be overwritten is a list, it should be like key=\"[a,b]\" or key=a,b. '\n             'It also allows nested list/tuple values, e.g. 
key=\"[(a,b),(c,d)]\". '\n             'Note that the quotation marks are necessary and that no white space '\n             'is allowed.')\n    parser.add_argument(\n        '--options',\n        nargs='+',\n        action=DictAction,\n        help='custom options for evaluation, the key-value pair in xxx=yyy '\n             'format will be kwargs for dataset.evaluate() function (deprecated), '\n             'change to --eval-options instead.')\n    parser.add_argument(\n        '--eval-options',\n        nargs='+',\n        action=DictAction,\n        help='custom options for evaluation, the key-value pair in xxx=yyy '\n             'format will be kwargs for dataset.evaluate() function')\n\n    parser.add_argument(\n        '--semantic',\n        action='store_true',\n        help=\"use precomputed semantic maps as an additional input\"\n    )\n    parser.add_argument(\n        '--rescale-depth',\n        action='store_true',\n        help=\"save depth maps resized to 300x100 (w x h) instead of full resolution\"\n    )\n    parser.add_argument(\n        '--with-seq',\n        action='store_true',\n        help=\"save outputs into per-sequence sub-folders\"\n    )\n    parser.add_argument(\n        '--launcher',\n        choices=['none', 'pytorch', 'slurm', 'mpi'],\n        default='none',\n        help='job launcher')\n    parser.add_argument('--local_rank', type=int, default=0)\n    # parser.add_argument('--output_dir', default=\"./work_dirs/vps/vps_output\",\n    #                     help='output result file in pickle format to load')\n    # parser.add_argument('--n_video', type=int, default=50, help=\"number of video clips\")\n    # parser.add_argument('--pan_im_json_file', type=str, default='data/cityscapes_vps/panoptic_im_val_city_vps.json')\n    args = parser.parse_args()\n    if 'LOCAL_RANK' not in os.environ:\n        os.environ['LOCAL_RANK'] = str(args.local_rank)\n\n    if args.options and args.eval_options:\n        raise ValueError(\n            '--options and --eval-options cannot both be '\n            'specified, --options is deprecated in favor of --eval-options')\n    if args.options:\n        
warnings.warn('--options is deprecated in favor of --eval-options')\n        args.eval_options = args.options\n    return args\n\n\ndef main():\n    args = parse_args()\n\n    assert args.out or args.eval or args.format_only or args.show \\\n           or args.show_dir, \\\n        ('Please specify at least one operation (save/eval/format/show the '\n         'results / save the results) with the argument \"--out\", \"--eval\"'\n         ', \"--format-only\", \"--show\" or \"--show-dir\"')\n\n    if args.eval and args.format_only:\n        raise ValueError('--eval and --format_only cannot be both specified')\n\n    if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):\n        raise ValueError('The output file must be a pkl file.')\n    print(args)\n\n    cfg = Config.fromfile(args.config)\n    if args.cfg_options is not None:\n        cfg.merge_from_dict(args.cfg_options)\n    # import modules from string list.\n    if cfg.get('custom_imports', None):\n        from mmcv.utils import import_modules_from_strings\n        import_modules_from_strings(**cfg['custom_imports'])\n    # set cudnn_benchmark\n    if cfg.get('cudnn_benchmark', False):\n        torch.backends.cudnn.benchmark = True\n    cfg.model.pretrained = None\n    if cfg.model.get('neck'):\n        if isinstance(cfg.model.neck, list):\n            for neck_cfg in cfg.model.neck:\n                if neck_cfg.get('rfp_backbone'):\n                    if neck_cfg.rfp_backbone.get('pretrained'):\n                        neck_cfg.rfp_backbone.pretrained = None\n        elif cfg.model.neck.get('rfp_backbone'):\n            if cfg.model.neck.rfp_backbone.get('pretrained'):\n                cfg.model.neck.rfp_backbone.pretrained = None\n\n    # in case the test dataset is concatenated\n    samples_per_gpu = 1\n    if isinstance(cfg.data.test, dict):\n        cfg.data.test.test_mode = True\n        samples_per_gpu = cfg.data.test.pop('samples_per_gpu', 1)\n        if samples_per_gpu > 1:\n        
    # Replace 'ImageToTensor' to 'DefaultFormatBundle'\n            cfg.data.test.pipeline = replace_ImageToTensor(\n                cfg.data.test.pipeline)\n    elif isinstance(cfg.data.test, list):\n        for ds_cfg in cfg.data.test:\n            ds_cfg.test_mode = True\n        samples_per_gpu = max(\n            [ds_cfg.pop('samples_per_gpu', 1) for ds_cfg in cfg.data.test])\n        if samples_per_gpu > 1:\n            for ds_cfg in cfg.data.test:\n                ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline)\n\n    # init distributed env first, since logger depends on the dist info.\n    if args.launcher == 'none':\n        distributed = False\n    else:\n        distributed = True\n        init_dist(args.launcher, **cfg.dist_params)\n\n    # build the dataloader\n    dataset = build_dataset(cfg.data.test)\n    data_loader = build_dataloader(\n        dataset,\n        samples_per_gpu=samples_per_gpu,\n        workers_per_gpu=cfg.data.workers_per_gpu,\n        dist=distributed,\n        shuffle=False)\n\n    # build the model and load checkpoint\n    cfg.model.train_cfg = None\n    model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))\n    fp16_cfg = cfg.get('fp16', None)\n    if fp16_cfg is not None:\n        wrap_fp16_model(model)\n    checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')\n    if args.fuse_conv_bn:\n        model = fuse_conv_bn(model)\n    # old versions did not save class info in checkpoints, this walkaround is\n    # for backward compatibility\n    if 'CLASSES' in checkpoint.get('meta', {}):\n        model.CLASSES = checkpoint['meta']['CLASSES']\n    else:\n        model.CLASSES = dataset.CLASSES\n\n    model = MMDataParallel(model, device_ids=[0])\n    # Inference the sequence\n    outputs, pred_pans_2ch = single_gpu_test(model, data_loader, args.show, args.show_dir,\n                                             args.show_score_thr, with_semantic_input=args.semantic,\n                             
                rescale_depth=args.rescale_depth, with_seq=args.with_seq)\n    print(\"==>Inference Depth VPS Done!\")\n\n    # Evaluation Part\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "tools/test_step.py",
    "content": "import argparse\nimport os\nimport os.path as osp\nimport warnings\nimport numpy as np\nimport pickle\nimport json\nimport cv2\nimport mmcv\nimport torch\nfrom mmcv import Config, DictAction\nfrom mmcv.cnn import fuse_conv_bn\nfrom mmcv.parallel import MMDataParallel, MMDistributedDataParallel\nfrom mmcv.runner import (get_dist_info, init_dist, load_checkpoint,\n                         wrap_fp16_model)\nfrom mmdet.datasets import (build_dataloader, build_dataset,\n                            replace_ImageToTensor)\nfrom mmdet.models import build_detector\ntry:\n    from mmcv.cnn import get_model_complexity_info\nexcept ImportError:\n    raise ImportError('Please upgrade mmcv to >0.6.2')\n\ndef single_gpu_test(model,\n                    data_loader,\n                    show=False,\n                    out_dir=None,\n                    show_score_thr=0.3,\n                    with_semantic_input=False,):\n    if out_dir is None:\n        out_dir = './out'\n    print(\"The output dir is {}\".format(out_dir))\n    model.eval()\n    results = []\n    dataset = data_loader.dataset\n    prog_bar = mmcv.ProgressBar(len(dataset))\n\n    pano_seg_2ch_list = []\n\n    for i, data in enumerate(data_loader):\n        seq_id = data['seq_id'][0].item()\n        img_id = data['img_id'][0].item()\n        data.pop('seq_id')\n\n        with torch.no_grad():\n            segm_results = model(return_loss=False, rescale=True, **data)\n\n        sseg_results, track_maps, _, _, _ = segm_results\n        batch_size = 1\n        # merge\n\n        # dump results\n        cat_path = os.path.join(out_dir, 'panoptic', str(seq_id), '{:06d}_{:06d}_cat.png'.format(seq_id, img_id))\n        ins_path = os.path.join(out_dir, 'panoptic', str(seq_id), '{:06d}_{:06d}_ins.png'.format(seq_id, img_id))\n        vis_path = os.path.join(out_dir, 'vis', str(seq_id), '{:06d}_{:06d}.png'.format(seq_id, img_id))\n        final_path = os.path.join(out_dir, 'final', '{:04d}'.format(seq_id), 
'{:06d}.png'.format(img_id))\n\n        # depth_final_rescale = mmcv.imresize(depth_final, (300, 100), interpolation='bilinear') \\\n        #     if depth_final is not None else None\n        # encode the panoptic map as RGB = (semantic, instance // 256, instance % 256)\n        final_map = np.stack([sseg_results.astype(np.uint8), (track_maps // 256).astype(np.uint8), (track_maps % 256).astype(np.uint8)], axis=-1)\n        # mmcv.imwrite goes through OpenCV, which expects BGR channel order\n        final_map = cv2.cvtColor(final_map, cv2.COLOR_RGB2BGR)\n        mmcv.imwrite(sseg_results.astype(np.uint16), cat_path)\n        mmcv.imwrite(track_maps.astype(np.uint16), ins_path)\n        # final map for evaluation\n        mmcv.imwrite(final_map, final_path)\n        # depth\n        # if depth_final_rescale is not None:\n        #     mmcv.imwrite((depth_final_rescale * 256).astype(np.uint16), dep_path)\n        #  vis\n        # mmcv.imwrite(np.concatenate((vis_sem, vis_tracker), axis=0), vis_path)\n\n        for _ in range(batch_size):\n            prog_bar.update()\n\n    return results, pano_seg_2ch_list\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(\n        description='MMDet test (and eval) a model')\n    parser.add_argument('config', help='test config file path')\n    parser.add_argument('checkpoint', help='checkpoint file')\n    parser.add_argument('--out', help='output result file in pickle format')\n    parser.add_argument(\n        '--fuse-conv-bn',\n        action='store_true',\n        help='Whether to fuse conv and bn, this will slightly increase '\n             'the inference speed')\n    parser.add_argument(\n        '--format-only',\n        action='store_true',\n        help='Format the output results without performing evaluation. 
It is '\n             'useful when you want to format the result to a specific format and '\n             'submit it to the test server')\n    parser.add_argument(\n        '--eval',\n        type=str,\n        nargs='+',\n        help='evaluation metrics, which depend on the dataset, e.g., \"bbox\",'\n             ' \"segm\", \"proposal\" for COCO, and \"mAP\", \"recall\" for PASCAL VOC')\n    parser.add_argument('--show', action='store_true', help='show results')\n    parser.add_argument(\n        '--show-dir', help='directory where the prediction maps will be saved (default: ./out)')\n    parser.add_argument(\n        '--show-score-thr',\n        type=float,\n        default=0.3,\n        help='score threshold (default: 0.3)')\n    parser.add_argument(\n        '--gpu-collect',\n        action='store_true',\n        help='whether to use gpu to collect results.')\n    parser.add_argument(\n        '--tmpdir',\n        help='tmp directory used for collecting results from multiple '\n             'workers, available when gpu-collect is not specified')\n    parser.add_argument(\n        '--cfg-options',\n        nargs='+',\n        action=DictAction,\n        help='override some settings in the used config, the key-value pair '\n             'in xxx=yyy format will be merged into config file. If the value to '\n             'be overwritten is a list, it should be like key=\"[a,b]\" or key=a,b. '\n             'It also allows nested list/tuple values, e.g. 
key=\"[(a,b),(c,d)]\". '\n             'Note that the quotation marks are necessary and that no white space '\n             'is allowed.')\n    parser.add_argument(\n        '--options',\n        nargs='+',\n        action=DictAction,\n        help='custom options for evaluation, the key-value pair in xxx=yyy '\n             'format will be kwargs for dataset.evaluate() function (deprecated), '\n             'change to --eval-options instead.')\n    parser.add_argument(\n        '--eval-options',\n        nargs='+',\n        action=DictAction,\n        help='custom options for evaluation, the key-value pair in xxx=yyy '\n             'format will be kwargs for dataset.evaluate() function')\n\n    parser.add_argument(\n        '--semantic',\n        action='store_true',\n        help=\"semantic input\"\n    )\n    parser.add_argument(\n        '--launcher',\n        choices=['none', 'pytorch', 'slurm', 'mpi'],\n        default='none',\n        help='job launcher')\n    parser.add_argument('--output_dir', default=\"./work_dirs/vps/vps_output\",\n                        help='output result file in pickle format to load')\n    parser.add_argument('--n_video', type=int, default=50, help=\"number of video clips\")\n    parser.add_argument('--pan_im_json_file', type=str, default='data/cityscapes_vps/panoptic_im_val_city_vps.json')\n    parser.add_argument('--local_rank', type=int, default=0)\n    args = parser.parse_args()\n    if 'LOCAL_RANK' not in os.environ:\n        os.environ['LOCAL_RANK'] = str(args.local_rank)\n\n    if args.options and args.eval_options:\n        raise ValueError(\n            '--options and --eval-options cannot both be '\n            'specified, --options is deprecated in favor of --eval-options')\n    if args.options:\n        warnings.warn('--options is deprecated in favor of --eval-options')\n        args.eval_options = args.options\n    return args\n\n\ndef main():\n    args = parse_args()\n\n    assert args.out or args.eval or 
args.format_only or args.show \\\n           or args.show_dir, \\\n        ('Please specify at least one operation (save/eval/format/show the '\n         'results) with the argument \"--out\", \"--eval\"'\n         ', \"--format-only\", \"--show\" or \"--show-dir\"')\n\n    if args.eval and args.format_only:\n        raise ValueError('--eval and --format-only cannot both be specified')\n\n    if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):\n        raise ValueError('The output file must be a pkl file.')\n    print(args)\n\n    cfg = Config.fromfile(args.config)\n    if args.cfg_options is not None:\n        cfg.merge_from_dict(args.cfg_options)\n    # import modules from string list.\n    if cfg.get('custom_imports', None):\n        from mmcv.utils import import_modules_from_strings\n        import_modules_from_strings(**cfg['custom_imports'])\n    # set cudnn_benchmark\n    if cfg.get('cudnn_benchmark', False):\n        torch.backends.cudnn.benchmark = True\n    cfg.model.pretrained = None\n    if cfg.model.get('neck'):\n        if isinstance(cfg.model.neck, list):\n            for neck_cfg in cfg.model.neck:\n                if neck_cfg.get('rfp_backbone'):\n                    if neck_cfg.rfp_backbone.get('pretrained'):\n                        neck_cfg.rfp_backbone.pretrained = None\n        elif cfg.model.neck.get('rfp_backbone'):\n            if cfg.model.neck.rfp_backbone.get('pretrained'):\n                cfg.model.neck.rfp_backbone.pretrained = None\n\n    # in case the test dataset is concatenated\n    samples_per_gpu = 1\n    if isinstance(cfg.data.test, dict):\n        cfg.data.test.test_mode = True\n        samples_per_gpu = cfg.data.test.pop('samples_per_gpu', 1)\n        if samples_per_gpu > 1:\n            # Replace 'ImageToTensor' with 'DefaultFormatBundle'\n            cfg.data.test.pipeline = replace_ImageToTensor(\n                cfg.data.test.pipeline)\n    elif isinstance(cfg.data.test, list):\n       
 for ds_cfg in cfg.data.test:\n            ds_cfg.test_mode = True\n        samples_per_gpu = max(\n            [ds_cfg.pop('samples_per_gpu', 1) for ds_cfg in cfg.data.test])\n        if samples_per_gpu > 1:\n            for ds_cfg in cfg.data.test:\n                ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline)\n\n    # init distributed env first, since logger depends on the dist info.\n    if args.launcher == 'none':\n        distributed = False\n    else:\n        distributed = True\n        init_dist(args.launcher, **cfg.dist_params)\n\n    # build the dataloader\n    dataset = build_dataset(cfg.data.test)\n    data_loader = build_dataloader(\n        dataset,\n        samples_per_gpu=samples_per_gpu,\n        workers_per_gpu=cfg.data.workers_per_gpu,\n        dist=distributed,\n        shuffle=False)\n\n    # build the model and load checkpoint\n    cfg.model.train_cfg = None\n    model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))\n    fp16_cfg = cfg.get('fp16', None)\n    if fp16_cfg is not None:\n        wrap_fp16_model(model)\n    checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')\n    if args.fuse_conv_bn:\n        model = fuse_conv_bn(model)\n    # old versions did not save class info in checkpoints, this workaround is\n    # for backward compatibility\n    if 'CLASSES' in checkpoint.get('meta', {}):\n        model.CLASSES = checkpoint['meta']['CLASSES']\n    else:\n        model.CLASSES = dataset.CLASSES\n\n    model = MMDataParallel(model, device_ids=[0])\n    # Run inference over the test sequences\n    outputs, pred_pans_2ch = single_gpu_test(model, data_loader, args.show, args.show_dir,\n                                             args.show_score_thr, with_semantic_input=args.semantic)\n    print(\"==>Inference STEP Done!\")\n\n    # Evaluation Part\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "tools/test_vps.py",
    "content": "import argparse\nimport os\nimport os.path as osp\nimport warnings\nimport numpy as np\nimport pickle\nimport json\nimport mmcv\nimport torch\nfrom mmcv import Config, DictAction\nfrom mmcv.cnn import fuse_conv_bn\nfrom mmcv.parallel import MMDataParallel, MMDistributedDataParallel\nfrom mmcv.runner import (get_dist_info, init_dist, load_checkpoint,\n                         wrap_fp16_model)\nfrom mmdet.datasets import (build_dataloader, build_dataset,\n                            replace_ImageToTensor)\nfrom mmdet.models import build_detector\n\nfrom external.test import encode_mask_results, tensor2imgs\n\n\ndef single_gpu_test(model,\n                    data_loader,\n                    show=False,\n                    out_dir=None,\n                    show_score_thr=0.3,\n                    with_semantic_input=False,\n                    rescale_depth=False,\n                    with_seq=False,\n                    ):\n    if out_dir is None:\n        out_dir = './out'\n    print(\"The output dir is {}\".format(out_dir))\n    model.eval()\n    results = []\n    dataset = data_loader.dataset\n    prog_bar = mmcv.ProgressBar(len(dataset))\n\n    pano_seg_2ch_list = []\n    # print(\"data loader length:\", len(data_loader))\n    # exit()\n    for i, data in enumerate(data_loader):\n        seq_id = data['seq_id'][0].item()\n        img_id = data['img_id'][0].item()\n        data.pop('seq_id')\n        with torch.no_grad():\n            segm_results = model(return_loss=False, rescale=True, **data)\n\n        sseg_results, track_maps, _,  _, _ = segm_results\n        batch_size = 1\n\n        # dump results\n        seq_folder = str(seq_id) if with_seq else \"\"\n        cat_path = os.path.join(out_dir, 'panoptic', seq_folder, '{:06d}_{:06d}_cat.png'.format(seq_id, img_id))\n        ins_path = os.path.join(out_dir, 'panoptic', seq_folder, '{:06d}_{:06d}_ins.png'.format(seq_id, img_id))\n\n        mmcv.imwrite(sseg_results.astype(np.uint16), 
cat_path)\n        mmcv.imwrite(track_maps.astype(np.uint16), ins_path)\n        # if depth_final_rescale is not None:\n        #     mmcv.imwrite(((depth_final_rescale if rescale_depth else depth_final) * 256.).astype(np.uint16), dep_path)\n        # mmcv.imwrite(np.concatenate((vis_sem, vis_tracker), axis=0), vis_path)\n\n        for _ in range(batch_size):\n            prog_bar.update()\n\n    return results, pano_seg_2ch_list\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(\n        description='MMDet test (and eval) a model')\n    parser.add_argument('config', help='test config file path')\n    parser.add_argument('checkpoint', help='checkpoint file')\n    parser.add_argument('--out', help='output result file in pickle format')\n    parser.add_argument(\n        '--fuse-conv-bn',\n        action='store_true',\n        help='Whether to fuse conv and bn, this will slightly increase '\n             'the inference speed')\n    parser.add_argument(\n        '--format-only',\n        action='store_true',\n        help='Format the output results without performing evaluation. 
It is '\n             'useful when you want to format the result to a specific format and '\n             'submit it to the test server')\n    parser.add_argument(\n        '--eval',\n        type=str,\n        nargs='+',\n        help='evaluation metrics, which depend on the dataset, e.g., \"bbox\",'\n             ' \"segm\", \"proposal\" for COCO, and \"mAP\", \"recall\" for PASCAL VOC')\n    parser.add_argument('--show', action='store_true', help='show results')\n    parser.add_argument(\n        '--show-dir', help='directory where the prediction maps will be saved (default: ./out)')\n    parser.add_argument(\n        '--show-score-thr',\n        type=float,\n        default=0.3,\n        help='score threshold (default: 0.3)')\n    parser.add_argument(\n        '--gpu-collect',\n        action='store_true',\n        help='whether to use gpu to collect results.')\n    parser.add_argument(\n        '--tmpdir',\n        help='tmp directory used for collecting results from multiple '\n             'workers, available when gpu-collect is not specified')\n    parser.add_argument(\n        '--cfg-options',\n        nargs='+',\n        action=DictAction,\n        help='override some settings in the used config, the key-value pair '\n             'in xxx=yyy format will be merged into config file. If the value to '\n             'be overwritten is a list, it should be like key=\"[a,b]\" or key=a,b. '\n             'It also allows nested list/tuple values, e.g. 
key=\"[(a,b),(c,d)]\". '\n             'Note that the quotation marks are necessary and that no white space '\n             'is allowed.')\n    parser.add_argument(\n        '--options',\n        nargs='+',\n        action=DictAction,\n        help='custom options for evaluation, the key-value pair in xxx=yyy '\n             'format will be kwargs for dataset.evaluate() function (deprecated), '\n             'change to --eval-options instead.')\n    parser.add_argument(\n        '--eval-options',\n        nargs='+',\n        action=DictAction,\n        help='custom options for evaluation, the key-value pair in xxx=yyy '\n             'format will be kwargs for dataset.evaluate() function')\n\n    parser.add_argument(\n        '--semantic',\n        action='store_true',\n        help=\"semantic input\"\n    )\n    parser.add_argument(\n        '--rescale-depth',\n        action='store_true',\n        help=\"\"\n    )\n    parser.add_argument(\n        '--with-seq',\n        action='store_true',\n        help=\"save the outputs of each sequence into its own sub-folder\"\n    )\n    parser.add_argument(\n        '--launcher',\n        choices=['none', 'pytorch', 'slurm', 'mpi'],\n        default='none',\n        help='job launcher')\n    parser.add_argument('--local_rank', type=int, default=0)\n    # parser.add_argument('--output_dir', default=\"./work_dirs/vps/vps_output\",\n    #                     help='output result file in pickle format to load')\n    # parser.add_argument('--n_video', type=int, default=50, help=\"number of video clips\")\n    # parser.add_argument('--pan_im_json_file', type=str, default='data/cityscapes_vps/panoptic_im_val_city_vps.json')\n    args = parser.parse_args()\n    if 'LOCAL_RANK' not in os.environ:\n        os.environ['LOCAL_RANK'] = str(args.local_rank)\n\n    if args.options and args.eval_options:\n        raise ValueError(\n            '--options and --eval-options cannot both be '\n            'specified, --options is deprecated in favor of --eval-options')\n    if args.options:\n        
warnings.warn('--options is deprecated in favor of --eval-options')\n        args.eval_options = args.options\n    return args\n\n\ndef main():\n    args = parse_args()\n\n    assert args.out or args.eval or args.format_only or args.show \\\n           or args.show_dir, \\\n        ('Please specify at least one operation (save/eval/format/show the '\n         'results) with the argument \"--out\", \"--eval\"'\n         ', \"--format-only\", \"--show\" or \"--show-dir\"')\n\n    if args.eval and args.format_only:\n        raise ValueError('--eval and --format-only cannot both be specified')\n\n    if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):\n        raise ValueError('The output file must be a pkl file.')\n    print(args)\n\n    cfg = Config.fromfile(args.config)\n    if args.cfg_options is not None:\n        cfg.merge_from_dict(args.cfg_options)\n    # import modules from string list.\n    if cfg.get('custom_imports', None):\n        from mmcv.utils import import_modules_from_strings\n        import_modules_from_strings(**cfg['custom_imports'])\n    # set cudnn_benchmark\n    if cfg.get('cudnn_benchmark', False):\n        torch.backends.cudnn.benchmark = True\n    cfg.model.pretrained = None\n    if cfg.model.get('neck'):\n        if isinstance(cfg.model.neck, list):\n            for neck_cfg in cfg.model.neck:\n                if neck_cfg.get('rfp_backbone'):\n                    if neck_cfg.rfp_backbone.get('pretrained'):\n                        neck_cfg.rfp_backbone.pretrained = None\n        elif cfg.model.neck.get('rfp_backbone'):\n            if cfg.model.neck.rfp_backbone.get('pretrained'):\n                cfg.model.neck.rfp_backbone.pretrained = None\n\n    # in case the test dataset is concatenated\n    samples_per_gpu = 1\n    if isinstance(cfg.data.test, dict):\n        cfg.data.test.test_mode = True\n        samples_per_gpu = cfg.data.test.pop('samples_per_gpu', 1)\n        if samples_per_gpu > 1:\n        
    # Replace 'ImageToTensor' with 'DefaultFormatBundle'\n            cfg.data.test.pipeline = replace_ImageToTensor(\n                cfg.data.test.pipeline)\n    elif isinstance(cfg.data.test, list):\n        for ds_cfg in cfg.data.test:\n            ds_cfg.test_mode = True\n        samples_per_gpu = max(\n            [ds_cfg.pop('samples_per_gpu', 1) for ds_cfg in cfg.data.test])\n        if samples_per_gpu > 1:\n            for ds_cfg in cfg.data.test:\n                ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline)\n\n    # init distributed env first, since logger depends on the dist info.\n    if args.launcher == 'none':\n        distributed = False\n    else:\n        distributed = True\n        init_dist(args.launcher, **cfg.dist_params)\n\n    # build the dataloader\n    dataset = build_dataset(cfg.data.test)\n    data_loader = build_dataloader(\n        dataset,\n        samples_per_gpu=samples_per_gpu,\n        workers_per_gpu=cfg.data.workers_per_gpu,\n        dist=distributed,\n        shuffle=False)\n\n    # build the model and load checkpoint\n    cfg.model.train_cfg = None\n    model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))\n    fp16_cfg = cfg.get('fp16', None)\n    if fp16_cfg is not None:\n        wrap_fp16_model(model)\n    checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')\n    if args.fuse_conv_bn:\n        model = fuse_conv_bn(model)\n    # old versions did not save class info in checkpoints, this workaround is\n    # for backward compatibility\n    if 'CLASSES' in checkpoint.get('meta', {}):\n        model.CLASSES = checkpoint['meta']['CLASSES']\n    else:\n        model.CLASSES = dataset.CLASSES\n\n    model = MMDataParallel(model, device_ids=[0])\n    # Run inference over the test sequences\n    outputs, pred_pans_2ch = single_gpu_test(model, data_loader, args.show, args.show_dir,\n                                             args.show_score_thr, with_semantic_input=args.semantic,\n                             
                rescale_depth=args.rescale_depth, with_seq=args.with_seq)\n    print(\"==>Inference Depth VPS Done!\")\n\n    # Evaluation Part\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "tools/train.py",
    "content": "import argparse\nimport copy\nimport os\nimport os.path as osp\nimport time\nimport warnings\n\nimport mmcv\nimport torch\nimport torch.distributed as dist\nfrom mmcv import Config, DictAction\nfrom mmcv.runner import get_dist_info, init_dist\nfrom mmcv.utils import get_git_hash\nfrom mmdet import __version__\nfrom mmdet.apis import set_random_seed\nfrom mmdet.datasets import build_dataset\nfrom mmdet.models import build_detector\nfrom mmdet.utils import collect_env, get_root_logger\n\nfrom external.train import train_detector\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description='Train a detector')\n    parser.add_argument('config', help='train config file path')\n    parser.add_argument('--work-dir', help='the dir to save logs and models')\n    parser.add_argument(\n        '--resume-from', help='the checkpoint file to resume from')\n    parser.add_argument(\n        '--load-from', help='the checkpoint file to load weights from')\n    parser.add_argument(\n        '--no-validate',\n        action='store_true',\n        help='do not evaluate the checkpoint during training')\n    group_gpus = parser.add_mutually_exclusive_group()\n    group_gpus.add_argument(\n        '--gpus',\n        type=int,\n        help='number of gpus to use '\n        '(only applicable to non-distributed training)')\n    group_gpus.add_argument(\n        '--gpu-ids',\n        type=int,\n        nargs='+',\n        help='ids of gpus to use '\n        '(only applicable to non-distributed training)')\n    parser.add_argument('--seed', type=int, default=None, help='random seed')\n    parser.add_argument(\n        '--deterministic',\n        action='store_true',\n        help='whether to set deterministic options for CUDNN backend.')\n    parser.add_argument(\n        '--detect-anomaly',\n        action='store_true',\n        help='enable torch.autograd anomaly detection during training')\n    parser.add_argument(\n        '--options',\n        nargs='+',\n        action=DictAction,\n        
help='override some settings in the used config, the key-value pair '\n        'in xxx=yyy format will be merged into config file (deprecated), '\n        'change to --cfg-options instead.')\n    parser.add_argument(\n        '--cfg-options',\n        nargs='+',\n        action=DictAction,\n        help='override some settings in the used config, the key-value pair '\n        'in xxx=yyy format will be merged into config file. If the value to '\n        'be overwritten is a list, it should be like key=\"[a,b]\" or key=a,b. '\n        'It also allows nested list/tuple values, e.g. key=\"[(a,b),(c,d)]\". '\n        'Note that the quotation marks are necessary and that no white space '\n        'is allowed.')\n    parser.add_argument(\n        '--launcher',\n        choices=['none', 'pytorch', 'slurm', 'mpi'],\n        default='none',\n        help='job launcher')\n    parser.add_argument('--local_rank', type=int, default=0)\n    args = parser.parse_args()\n    if 'LOCAL_RANK' not in os.environ:\n        os.environ['LOCAL_RANK'] = str(args.local_rank)\n\n    if args.options and args.cfg_options:\n        raise ValueError(\n            '--options and --cfg-options cannot both be '\n            'specified, --options is deprecated in favor of --cfg-options')\n    if args.options:\n        warnings.warn('--options is deprecated in favor of --cfg-options')\n        args.cfg_options = args.options\n\n    return args\n\n\ndef main():\n    args = parse_args()\n\n    cfg = Config.fromfile(args.config)\n    if args.cfg_options is not None:\n        cfg.merge_from_dict(args.cfg_options)\n    # import modules from string list.\n    if cfg.get('custom_imports', None):\n        from mmcv.utils import import_modules_from_strings\n        import_modules_from_strings(**cfg['custom_imports'])\n    # set cudnn_benchmark\n    if cfg.get('cudnn_benchmark', False):\n        torch.backends.cudnn.benchmark = True\n\n    # work_dir is determined in this priority: CLI > segment in file > 
filename\n    if args.work_dir is not None:\n        # update configs according to CLI args if args.work_dir is not None\n        cfg.work_dir = args.work_dir\n    elif cfg.get('work_dir', None) is None:\n        # use config filename as default work_dir if cfg.work_dir is None\n        cfg.work_dir = osp.join('./work_dirs',\n                                osp.splitext(osp.basename(args.config))[0])\n    if args.resume_from is not None:\n        cfg.resume_from = args.resume_from\n    if args.load_from is not None:\n        cfg.load_from = args.load_from\n    if args.gpu_ids is not None:\n        cfg.gpu_ids = args.gpu_ids\n    else:\n        cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)\n\n    # init distributed env first, since logger depends on the dist info.\n    if args.launcher == 'none':\n        distributed = False\n    else:\n        distributed = True\n        init_dist(args.launcher, **cfg.dist_params)\n        dist.barrier()\n        # re-set gpu_ids with distributed training mode\n        _, world_size = get_dist_info()\n        cfg.gpu_ids = range(world_size)\n\n    # create work_dir\n    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))\n    # dump config\n    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))\n    # init the logger before other steps\n    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())\n    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')\n    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)\n\n    # Added in PansegMM\n    # Log the git hash so that each experiment can be traced back to its code version\n    logger.info('The repo is : https://github.com/lxtGH/PanopticSegMM/tree/{}/'.format(get_git_hash()))\n    logger.info('The config is : https://github.com/lxtGH/PanopticSegMM/tree/{}/{}'.format(get_git_hash(), args.config))\n\n    # init the meta dict to record some important information such as\n    # environment info and seed, which will be logged\n    meta = dict()\n    # log env info\n  
  env_info_dict = collect_env()\n    env_info = '\\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])\n    dash_line = '-' * 60 + '\\n'\n    logger.info('Environment info:\\n' + dash_line + env_info + '\\n' +\n                dash_line)\n    meta['env_info'] = env_info\n    meta['config'] = cfg.pretty_text\n    # log some basic info\n    logger.info(f'Distributed training: {distributed}')\n    logger.info(f'Config:\\n{cfg.pretty_text}')\n\n    # set random seeds\n    if args.seed is not None:\n        logger.info(f'Set random seed to {args.seed}, '\n                    f'deterministic: {args.deterministic}')\n        set_random_seed(args.seed, deterministic=args.deterministic)\n    else:\n        set_random_seed(0, deterministic=args.deterministic)\n    cfg.seed = args.seed\n    meta['seed'] = args.seed\n    meta['exp_name'] = osp.basename(args.config)\n\n    model = build_detector(\n        cfg.model,\n        train_cfg=cfg.get('train_cfg'),\n        test_cfg=cfg.get('test_cfg'))\n    model.init_weights()\n\n    logger.info(f'Model:\\n{model}')\n    datasets = [build_dataset(cfg.data.train)]\n    if len(cfg.workflow) == 2:\n        val_dataset = copy.deepcopy(cfg.data.val)\n        val_dataset.pipeline = cfg.data.train.pipeline\n        datasets.append(build_dataset(val_dataset))\n    if cfg.checkpoint_config is not None:\n        # save mmdet version, config file content and class names in\n        # checkpoints as meta data\n        cfg.checkpoint_config.meta = dict(\n            mmdet_version=__version__ + get_git_hash()[:7],\n            CLASSES=datasets[0].CLASSES)\n    # add an attribute for visualization convenience\n    model.CLASSES = datasets[0].CLASSES\n    if args.detect_anomaly:\n        with torch.autograd.detect_anomaly():\n            train_detector(\n                model,\n                datasets,\n                cfg,\n                distributed=distributed,\n                validate=(not args.no_validate),\n                
timestamp=timestamp,\n                meta=meta)\n    else:\n        train_detector(\n            model,\n            datasets,\n            cfg,\n            distributed=distributed,\n            validate=(not args.no_validate),\n            timestamp=timestamp,\n            meta=meta)\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "tools/utils/DSTQ.py",
    "content": "from typing import Sequence, Tuple\nimport collections\n\nimport numpy as np\n\nfrom .STQ import STQuality\n\n\nclass DSTQuality(STQuality):\n    def __init__(\n            self,\n            num_classes: int,\n            things_list: Sequence[int],\n            ignore_label: int,\n            label_bit_shift: int,\n            offset: int,\n            depth_threshold: Tuple[float] = (1.25, 1.1),\n            name: str = 'dstq'\n    ):\n        super().__init__(\n            num_classes=num_classes,\n            things_list=things_list,\n            ignore_label=ignore_label,\n            label_bit_shift=label_bit_shift,\n            offset=offset\n        )\n        if not (isinstance(depth_threshold, tuple) or\n                isinstance(depth_threshold, list)):\n            raise TypeError('The type of depth_threshold must be tuple or list.')\n        if not depth_threshold:\n            raise ValueError('depth_threshold must be non-empty.')\n        self._depth_threshold = tuple(depth_threshold)\n        self._depth_total_counts = collections.OrderedDict()\n        self._depth_inlier_counts = []\n        for _ in range(len(self._depth_threshold)):\n            self._depth_inlier_counts.append(collections.OrderedDict())\n\n    def update_state(\n            self,\n            y_true: np.ndarray,\n            y_pred: np.ndarray,\n            d_true: np.ndarray,\n            d_pred: np.ndarray,\n            sequence_id: int = 0\n    ):\n        \"\"\"Accumulates the depth-aware segmentation and tracking quality statistics.\n        Args:\n          y_true: The ground-truth panoptic label map for a particular video frame\n            (defined as semantic_map * max_instances_per_category + instance_map).\n          y_pred: The predicted panoptic label map for a particular video frame\n            (defined as semantic_map * max_instances_per_category + instance_map).\n          d_true: The ground-truth depth map for this video frame.\n          
d_pred: The predicted depth map for this video frame.\n          sequence_id: The optional ID of the sequence the frames belong to. When no\n            sequence is given, all frames are considered to belong to the same\n            sequence (default: 0).\n        \"\"\"\n        super().update_state(y_true, y_pred, sequence_id)\n        # Valid depth labels contain positive values.\n        d_valid_mask = d_true > 0\n        d_valid_total = np.sum(d_valid_mask.astype(int))\n        # Valid depth prediction is expected to contain positive values.\n        # TODO : very wrong implementation because it is hackable\n        d_valid_mask = np.logical_and(d_valid_mask, d_pred > 0)\n        d_valid_true = d_true[d_valid_mask]\n        d_valid_pred = d_pred[d_valid_mask]\n        inlier_error = np.maximum(d_valid_pred / d_valid_true,\n                                  d_valid_true / d_valid_pred)\n        # For each threshold, count the number of inliers.\n        for threshold_index, threshold in enumerate(self._depth_threshold):\n            num_inliers = np.sum((inlier_error <= threshold).astype(int))\n            inlier_counts = self._depth_inlier_counts[threshold_index]\n            inlier_counts[sequence_id] = (inlier_counts.get(sequence_id, 0) + int(num_inliers))\n        # Update the total counts of the depth labels.\n        self._depth_total_counts[sequence_id] = (\n                self._depth_total_counts.get(sequence_id, 0) + int(d_valid_total))\n\n    def result(self):\n        \"\"\"Computes the depth-aware segmentation and tracking quality.\n        Returns:\n          A dictionary containing:\n            - 'STQ': The total STQ score.\n            - 'AQ': The total association quality (AQ) score.\n            - 'IoU': The total mean IoU.\n            - 'STQ_per_seq': A list of the STQ score per sequence.\n            - 'AQ_per_seq': A list of the AQ score per sequence.\n            - 'IoU_per_seq': A list of mean IoU per sequence.\n            - 
'ID_per_seq': A list of sequence IDs to map list index to sequence.\n            - 'Length_per_seq': A list of the length of each sequence.\n            - 'DSTQ': The total DSTQ score.\n            - 'DSTQ@thres': The total DSTQ score for threshold thres.\n            - 'DSTQ_per_seq@thres': A list of DSTQ scores per sequence for thres.\n            - 'DQ': The total DQ score.\n            - 'DQ@thres': The total DQ score for threshold thres.\n            - 'DQ_per_seq@thres': A list of DQ scores per sequence for thres.\n        \"\"\"\n        # Gather the results for STQ.\n        stq_results = super().result()\n        # Collect results for depth quality per sequence and threshold.\n        dq_per_seq_at_threshold = {}\n        dq_at_threshold = {}\n        for threshold_index, threshold in enumerate(self._depth_threshold):\n            dq_per_seq_at_threshold[threshold] = [0] * len(self._ground_truth)\n            total_count = 0\n            inlier_count = 0\n            # Follow the order of computing STQ by enumerating _ground_truth.\n            for index, sequence_id in enumerate(self._ground_truth):\n                sequence_inlier = self._depth_inlier_counts[threshold_index][sequence_id]\n                sequence_total = self._depth_total_counts[sequence_id]\n                if sequence_total > 0:\n                    dq_per_seq_at_threshold[threshold][\n                        index] = sequence_inlier / sequence_total\n                total_count += sequence_total\n                inlier_count += sequence_inlier\n            if total_count == 0:\n                dq_at_threshold[threshold] = 0\n            else:\n                dq_at_threshold[threshold] = inlier_count / total_count\n        # Compute DQ as the geometric mean of DQ's at different thresholds.\n        dq = 1\n        for _, threshold in enumerate(self._depth_threshold):\n            dq *= dq_at_threshold[threshold]\n        dq = dq ** (1 / len(self._depth_threshold))\n        dq_results = 
{}\n        dq_results['DQ'] = dq\n        for _, threshold in enumerate(self._depth_threshold):\n            dq_results['DQ@{}'.format(threshold)] = dq_at_threshold[threshold]\n            dq_results['DQ_per_seq@{}'.format(\n                threshold)] = dq_per_seq_at_threshold[threshold]\n        # Combine STQ and DQ to get DSTQ.\n        dstq_results = {}\n        dstq_results['DSTQ'] = (stq_results['STQ'] ** 2 * dq) ** (1 / 3)\n        for _, threshold in enumerate(self._depth_threshold):\n            dstq_results['DSTQ@{}'.format(threshold)] = (stq_results['STQ'] ** 2 * dq_at_threshold[\n                                                            threshold]) ** (1 / 3)\n            dstq_results['DSTQ_per_seq@{}'.format(threshold)] = [\n                (stq_result ** 2 * dq_result) ** (1 / 3) for stq_result, dq_result in zip(\n                    stq_results['STQ_per_seq'], dq_per_seq_at_threshold[threshold])\n            ]\n        # Merge all the results.\n        dstq_results.update(stq_results)\n        dstq_results.update(dq_results)\n        return dstq_results\n\n    def reset_states(self):\n        \"\"\"Resets all states that accumulated data.\"\"\"\n        super().reset_states()\n        self._depth_total_counts = collections.OrderedDict()\n        self._depth_inlier_counts = []\n        for _ in range(len(self._depth_threshold)):\n            self._depth_inlier_counts.append(collections.OrderedDict())\n"
  },
  {
    "path": "tools/utils/STQ.py",
    "content": "# This file is copied from deeplab2, please refer to https://github.com/google-research/deeplab2/\n# for details. Please cite their papers if this file is helpful.\n\n# coding=utf-8\n# Copyright 2021 The Deeplab2 Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Numpy Implementation of the Segmentation and Tracking Quality (STQ) metric.\nThis implementation is designed to work stand-alone. Please feel free to copy\nthis file and the corresponding unit-test to your project.\n\"\"\"\n\nimport collections\nfrom typing import Mapping, MutableMapping, Sequence, Text, Any\nimport numpy as np\n\n_EPSILON = 1e-15\n\n\ndef _update_dict_stats(stat_dict: MutableMapping[int, np.ndarray],\n                       id_array: np.ndarray):\n    \"\"\"Updates a given dict with corresponding counts.\"\"\"\n    ids, counts = np.unique(id_array, return_counts=True)\n    for idx, count in zip(ids, counts):\n        if idx in stat_dict:\n            stat_dict[idx] += count\n        else:\n            stat_dict[idx] = count\n\n\nclass STQuality(object):\n    \"\"\"Metric class for the Segmentation and Tracking Quality (STQ).\n    Please see the following paper for more details about the metric:\n    \"STEP: Segmenting and Tracking Every Pixel\", Weber et al., arXiv:2102.11859,\n    2021.\n    The metric computes the geometric mean of two terms.\n    - Association Quality: This term measures the quality of the track ID\n        assignment for 
`thing` classes. It is formulated as a weighted IoU\n        measure.\n    - Segmentation Quality: This term measures the semantic segmentation quality.\n        The standard class IoU measure is used for this.\n    Example usage:\n    stq_obj = segmentation_tracking_quality.STQuality(num_classes, things_list,\n      ignore_label, label_bit_shift, offset)\n    stq_obj.update_state(y_true_1, y_pred_1)\n    stq_obj.update_state(y_true_2, y_pred_2)\n    ...\n    result = stq_obj.result()\n    \"\"\"\n\n    def __init__(self, num_classes: int, things_list: Sequence[int],\n                 ignore_label: int, label_bit_shift: int, offset: int):\n        \"\"\"Initialization of the STQ metric.\n        Args:\n          num_classes: Number of classes in the dataset as an integer.\n          things_list: A sequence of class ids that belong to `things`.\n          ignore_label: The class id to be ignored in evaluation as an integer or\n            integer tensor.\n          label_bit_shift: The number of bits the class label is shifted as an\n            integer -> (class_label << bits) + trackingID\n          offset: The maximum number of unique labels as an integer or integer\n            tensor.\n        \"\"\"\n        self._num_classes = num_classes\n        self._ignore_label = ignore_label\n        self._things_list = things_list\n        self._label_bit_shift = label_bit_shift\n        self._bit_mask = (2 ** label_bit_shift) - 1\n\n        if ignore_label >= num_classes:\n            self._confusion_matrix_size = num_classes + 1\n            self._include_indices = np.arange(self._num_classes)\n        else:\n            self._confusion_matrix_size = num_classes\n            self._include_indices = np.array(\n                [i for i in range(num_classes) if i != self._ignore_label])\n\n        self._iou_confusion_matrix_per_sequence = collections.OrderedDict()\n        self._predictions = collections.OrderedDict()\n        self._ground_truth = 
collections.OrderedDict()\n        self._intersections = collections.OrderedDict()\n        self._sequence_length = collections.OrderedDict()\n        self._offset = offset\n        lower_bound = num_classes << self._label_bit_shift\n        if offset < lower_bound:\n            raise ValueError('The provided offset %d is too small. No guarantees '\n                             'about the correctness of the results can be made. '\n                             'Please choose an offset that is at least num_classes'\n                             ' * max_instances_per_category = %d' % (offset, lower_bound))\n\n    def get_semantic(self, y: np.ndarray) -> np.ndarray:\n        \"\"\"Returns the semantic class from a panoptic label map.\"\"\"\n        return y >> self._label_bit_shift\n\n    def update_state(self, y_true: np.ndarray, y_pred: np.ndarray, sequence_id=0):\n        \"\"\"Accumulates the segmentation and tracking quality statistics.\n        IMPORTANT: When encoding the parameters y_true and y_pred, please be aware\n        that the `+` operator binds higher than the label shift `<<` operator.\n        Args:\n          y_true: The ground-truth panoptic label map for a particular video frame\n            (defined as (semantic_map << label_bit_shift) + instance_map).\n          y_pred: The predicted panoptic label map for a particular video frame\n            (defined as (semantic_map << label_bit_shift) + instance_map).\n          sequence_id: The optional ID of the sequence the frames belong to. When no\n            sequence is given, all frames are considered to belong to the same\n            sequence (default: 0).\n        \"\"\"\n        y_true = y_true.astype(np.int64)\n        y_pred = y_pred.astype(np.int64)\n\n        semantic_label = self.get_semantic(y_true)\n        semantic_prediction = self.get_semantic(y_pred)\n        # Check if the ignore value is outside the range [0, num_classes]. 
If yes,\n        # map `_ignore_label` to `_num_classes`, so it can be used to create the\n        # confusion matrix.\n        if self._ignore_label > self._num_classes:\n            semantic_label = np.where(semantic_label != self._ignore_label,\n                                      semantic_label, self._num_classes)\n            semantic_prediction = np.where(semantic_prediction != self._ignore_label,\n                                           semantic_prediction, self._num_classes)\n        if sequence_id in self._iou_confusion_matrix_per_sequence:\n            idxs = (np.reshape(semantic_label, [-1]) <<\n                    self._label_bit_shift) + np.reshape(semantic_prediction, [-1])\n            unique_idxs, counts = np.unique(idxs, return_counts=True)\n            self._iou_confusion_matrix_per_sequence[sequence_id][\n                unique_idxs >> self._label_bit_shift,\n                unique_idxs & self._bit_mask] += counts\n            self._sequence_length[sequence_id] += 1\n        else:\n            self._iou_confusion_matrix_per_sequence[sequence_id] = np.zeros(\n                (self._confusion_matrix_size, self._confusion_matrix_size),\n                dtype=np.int64)\n            idxs = np.stack([\n                np.reshape(semantic_label, [-1]),\n                np.reshape(semantic_prediction, [-1])\n            ],\n                axis=0)\n            np.add.at(self._iou_confusion_matrix_per_sequence[sequence_id],\n                      tuple(idxs), 1)\n\n            self._predictions[sequence_id] = {}\n            self._ground_truth[sequence_id] = {}\n            self._intersections[sequence_id] = {}\n            self._sequence_length[sequence_id] = 1\n\n        instance_label = y_true & self._bit_mask  # self._bit_mask == 2 ** label_bit_shift - 1\n\n        label_mask = np.zeros_like(semantic_label, dtype=bool)\n        prediction_mask = np.zeros_like(semantic_prediction, dtype=bool)\n        for things_class_id in self._things_list:\n            label_mask = np.logical_or(label_mask, semantic_label == things_class_id)\n            prediction_mask = np.logical_or(prediction_mask,\n                                            semantic_prediction == things_class_id)\n\n        # Select the `crowd` region of the current class. This region is encoded\n        # as instance id `0`.\n        is_crowd = np.logical_and(instance_label == 0, label_mask)\n        # Select the non-crowd region of the corresponding class as the `crowd`\n        # region is ignored for the tracking term.\n        label_mask = np.logical_and(label_mask, np.logical_not(is_crowd))\n        # Do not punish id assignment for regions that are annotated as `crowd` in\n        # the ground-truth.\n        prediction_mask = np.logical_and(prediction_mask, np.logical_not(is_crowd))\n\n        seq_preds = self._predictions[sequence_id]\n        seq_gts = self._ground_truth[sequence_id]\n        seq_intersects = self._intersections[sequence_id]\n\n        # Compute and update areas of ground-truth, predictions and intersections.\n        _update_dict_stats(seq_preds, y_pred[prediction_mask])\n        _update_dict_stats(seq_gts, y_true[label_mask])\n\n        non_crowd_intersection = np.logical_and(label_mask, prediction_mask)\n        intersection_ids = (\n                y_true[non_crowd_intersection] * self._offset +\n                y_pred[non_crowd_intersection])\n        _update_dict_stats(seq_intersects, intersection_ids)\n\n    def result(self) -> Mapping[Text, Any]:\n        \"\"\"Computes the segmentation and tracking quality.\n        Returns:\n          A dictionary containing:\n            - 'STQ': The total STQ score.\n            - 'AQ': The total association quality (AQ) score.\n            - 'IoU': The total mean IoU.\n            - 'STQ_per_seq': A list of the STQ score per sequence.\n            - 'AQ_per_seq': A list of the AQ score per sequence.\n            - 'IoU_per_seq': A list of mean IoU per sequence.\n            - 'ID_per_seq': A list of sequence IDs to map list index to\n                sequence.\n            - 'Length_per_seq': A list of the length of each sequence.\n        \"\"\"\n        # Compute association quality (AQ)\n        num_tubes_per_seq = [0] * len(self._ground_truth)\n        aq_per_seq = [0] * len(self._ground_truth)\n        iou_per_seq = [0] * len(self._ground_truth)\n        id_per_seq = [''] * len(self._ground_truth)\n\n        for index, sequence_id in enumerate(self._ground_truth):\n            outer_sum = 0.0\n            predictions = self._predictions[sequence_id]\n            ground_truth = self._ground_truth[sequence_id]\n            intersections = self._intersections[sequence_id]\n            num_tubes_per_seq[index] = len(ground_truth)\n            id_per_seq[index] = sequence_id\n\n            for gt_id, gt_size in ground_truth.items():\n                inner_sum = 0.0\n                for pr_id, pr_size in predictions.items():\n                    tpa_key = self._offset * gt_id + pr_id\n                    if tpa_key in intersections:\n                        tpa = intersections[tpa_key]\n                        fpa = pr_size - tpa\n                        fna = gt_size - tpa\n                        inner_sum += tpa * (tpa / (tpa + fpa + fna))\n\n                outer_sum += 1.0 / gt_size * inner_sum\n            aq_per_seq[index] = outer_sum\n\n        aq_mean = np.sum(aq_per_seq) / np.maximum(\n            np.sum(num_tubes_per_seq), _EPSILON)\n        aq_per_seq = aq_per_seq / np.maximum(num_tubes_per_seq, _EPSILON)\n\n        # Compute IoU scores.\n        # The rows correspond to ground-truth and the columns to predictions.\n        # Remove fp from confusion matrix for the void/ignore class.\n        total_confusion = np.zeros(\n            (self._confusion_matrix_size, self._confusion_matrix_size),\n            dtype=np.int64)\n        for index, confusion in enumerate(\n                
self._iou_confusion_matrix_per_sequence.values()):\n            removal_matrix = np.zeros_like(confusion)\n            removal_matrix[self._include_indices, :] = 1.0\n            confusion *= removal_matrix\n            total_confusion += confusion\n\n            # `intersections` corresponds to true positives.\n            intersections = confusion.diagonal()\n            fps = confusion.sum(axis=0) - intersections\n            fns = confusion.sum(axis=1) - intersections\n            unions = intersections + fps + fns\n\n            num_classes = np.count_nonzero(unions)\n            ious = (\n                    intersections.astype(np.double) /\n                    np.maximum(unions, 1e-15).astype(np.double))\n            iou_per_seq[index] = np.sum(ious) / num_classes\n\n        # `intersections` corresponds to true positives.\n        intersections = total_confusion.diagonal()\n        fps = total_confusion.sum(axis=0) - intersections\n        fns = total_confusion.sum(axis=1) - intersections\n        unions = intersections + fps + fns\n\n        num_classes = np.count_nonzero(unions)\n        ious = (\n                intersections.astype(np.double) /\n                np.maximum(unions, _EPSILON).astype(np.double))\n        iou_mean = np.sum(ious) / num_classes\n\n        st_quality = np.sqrt(aq_mean * iou_mean)\n        st_quality_per_seq = np.sqrt(aq_per_seq * iou_per_seq)\n        return {\n            'STQ': st_quality,\n            'AQ': aq_mean,\n            'IoU': float(iou_mean),\n            'STQ_per_seq': st_quality_per_seq,\n            'AQ_per_seq': aq_per_seq,\n            'IoU_per_seq': iou_per_seq,\n            'ID_per_seq': id_per_seq,\n            'Length_per_seq': list(self._sequence_length.values()),\n        }\n\n    def reset_states(self):\n        \"\"\"Resets all states that accumulated data.\"\"\"\n        self._iou_confusion_matrix_per_sequence = collections.OrderedDict()\n        self._predictions = collections.OrderedDict()\n        
self._ground_truth = collections.OrderedDict()\n        self._intersections = collections.OrderedDict()\n        self._sequence_length = collections.OrderedDict()\n"
  },
  {
    "path": "tools/utils/cityscapesvps_eval.py",
    "content": "from __future__ import print_function\n\nimport argparse\nimport os\nimport os.path as osp\nimport torch.multiprocessing as multiprocessing\nimport numpy as np\nimport json\nfrom PIL import Image\nimport pickle\nfrom torch.utils.data import Dataset\n\n\nclass CityscapesVps(Dataset):\n\n    def __init__(self):\n\n        super(CityscapesVps, self).__init__()\n\n        self.nframes_per_video = 6\n        self.lambda_ = 5\n        self.labeled_fid = 20\n\n    def _save_image_single_core(self, proc_id, images_set, names_set, colors = None):\n\n        def colorize(gray, palette):\n            # gray: numpy array of the label and 1*3N size list palette\n            color = Image.fromarray(gray.astype(np.uint8)).convert('P')\n            color.putpalette(palette)\n            return color\n\n        for working_idx, (image, name) in enumerate(zip(images_set, names_set)):\n            if colors is not None:\n                image = colorize(image, colors)\n            else:\n                image = Image.fromarray(image)\n            os.makedirs(os.path.dirname(name), exist_ok=True)\n            image.save(name)\n\n    def inference_panoptic_video(self, pred_pans_2ch, output_dir,\n                                 categories,\n                                 names,\n                                 n_video=0):\n        from panopticapi.utils import IdGenerator\n\n        # Sample only frames with GT annotations.\n        if len(pred_pans_2ch) != len(names):\n            pred_pans_2ch = pred_pans_2ch[(self.labeled_fid // self.lambda_)::self.lambda_]\n        categories = {el['id']: el for el in categories}\n        color_generator = IdGenerator(categories)\n\n        def get_pred_large(pan_2ch_all, vid_num, nframes_per_video=6):\n            vid_num = len(pan_2ch_all) // nframes_per_video  # 10\n            cpu_num = multiprocessing.cpu_count() // 2  # 32 --> 16\n            nprocs = min(vid_num, cpu_num)  # 10\n            max_nframes = cpu_num * 
nframes_per_video\n            nsplits = (len(pan_2ch_all) - 1) // max_nframes + 1\n            annotations, pan_all = [], []\n            for i in range(0, len(pan_2ch_all), max_nframes):\n                print('==> Read and convert VPS output - split %d/%d' % ((i // max_nframes) + 1, nsplits))\n                pan_2ch_part = pan_2ch_all[i:min(\n                    i + max_nframes, len(pan_2ch_all))]\n                pan_2ch_split = np.array_split(pan_2ch_part, nprocs)\n                workers = multiprocessing.Pool(processes=nprocs)\n                processes = []\n                for proc_id, pan_2ch_set in enumerate(pan_2ch_split):\n                    p = workers.apply_async(\n                        self.converter_2ch_track_core,\n                        (proc_id, pan_2ch_set, color_generator))\n                    processes.append(p)\n                workers.close()\n                workers.join()\n\n                for p in processes:\n                    p = p.get()\n                    annotations.extend(p[0])\n                    pan_all.extend(p[1])\n\n            pan_json = {'annotations': annotations}\n            return pan_all, pan_json\n\n        def save_image(images, save_folder, names, colors=None):\n            os.makedirs(save_folder, exist_ok=True)\n\n            names = [osp.join(save_folder,\n                              name.replace('_leftImg8bit', '').replace('_newImg8bit', '').replace('jpg', 'png').replace(\n                                  'jpeg', 'png')) for name in names]\n            cpu_num = multiprocessing.cpu_count() // 2\n            images_split = np.array_split(images, cpu_num)\n            names_split = np.array_split(names, cpu_num)\n            workers = multiprocessing.Pool(processes=cpu_num)\n            for proc_id, (images_set, names_set) in enumerate(zip(images_split, names_split)):\n                workers.apply_async(self._save_image_single_core, (proc_id, images_set, names_set, colors))\n            
workers.close()\n            workers.join()\n\n        # inference_panoptic_video\n        pred_pans, pred_json = get_pred_large(pred_pans_2ch,\n                                              vid_num=n_video)\n        print('--------------------------------------')\n        print('==> Saving VPS output png files')\n        os.makedirs(output_dir, exist_ok=True)\n        save_image(pred_pans_2ch, osp.join(output_dir, 'pan_2ch'), names)\n        save_image(pred_pans, osp.join(output_dir, 'pan_pred'), names)\n        print('==> Saving pred.jsons file')\n        json.dump(pred_json, open(osp.join(output_dir, 'pred.json'), 'w'))\n        print('--------------------------------------')\n\n        return pred_pans, pred_json\n\n    def converter_2ch_track_core(self, proc_id, pan_2ch_set, color_generator):\n        from panopticapi.utils import rgb2id\n\n        OFFSET = 1000\n        VOID = 255\n        annotations, pan_all = [], []\n        # reference dict to used color\n        inst2color = {}\n        for idx in range(len(pan_2ch_set)):\n            pan_2ch = np.uint32(pan_2ch_set[idx])\n            # pan_2ch: ss-seg maps[:,:,0], id-seg maps[:,:,1]\n            pan = OFFSET * pan_2ch[:, :, 0] + pan_2ch[:, :, 1]\n\n            pan_format = np.zeros((pan_2ch.shape[0], pan_2ch.shape[1], 3), dtype=np.uint8)\n            l = np.unique(pan)\n\n            segm_info = {}\n            for el in l:\n                sem = el // OFFSET\n\n                if sem == VOID:\n                    continue\n                mask = pan == el\n                #### handling used color for inst id\n                if el % OFFSET > 0:\n                    # if el > OFFSET:\n                    # things class\n                    if el in inst2color:\n                        color = inst2color[el]\n                    else:\n                        color = color_generator.get_color(sem)\n                        inst2color[el] = color\n                else:\n                    # stuff class\n  
                  color = color_generator.get_color(sem)\n\n                pan_format[mask] = color\n                index = np.where(mask)\n                x = index[1].min()\n                y = index[0].min()\n                width = index[1].max() - x\n                height = index[0].max() - y\n\n                dt = {\"category_id\": sem.item(), \"iscrowd\": 0, \"id\": int(rgb2id(color)),\n                      \"bbox\": [x.item(), y.item(), width.item(), height.item()], \"area\": mask.sum().item()}\n                segment_id = int(rgb2id(color))\n                segm_info[segment_id] = dt\n\n            # annotations.append({\"segments_info\": segm_info})\n            pan_all.append(pan_format)\n\n            gt_pan = np.uint32(pan_format)\n            # rgb2id for evaluation\n            pan_gt = gt_pan[:, :, 0] + gt_pan[:, :, 1] * 256 + gt_pan[:, :, 2] * 256 * 256\n            labels, labels_cnt = np.unique(pan_gt, return_counts=True)\n            for label, area in zip(labels, labels_cnt):\n                if label == 0:\n                    continue\n                if label not in segm_info.keys():\n                    print('label:', label)\n                    raise KeyError('label not in segm_info keys.')\n\n                segm_info[label][\"area\"] = int(area)\n            segm_info = [v for k, v in segm_info.items()]\n\n            annotations.append({\"segments_info\": segm_info})\n\n        return annotations, pan_all"
  },
  {
    "path": "tools/visualization.py",
    "content": "import argparse\nimport os\nimport os.path as osp\nimport warnings\nimport numpy as np\nimport pickle\nimport json\nimport mmcv\nimport torch\nfrom mmcv import Config, DictAction\nfrom mmcv.cnn import fuse_conv_bn\nfrom mmcv.parallel import MMDataParallel, MMDistributedDataParallel\nfrom mmcv.runner import (get_dist_info, init_dist, load_checkpoint,\n                         wrap_fp16_model)\nfrom mmdet.datasets import (build_dataloader, build_dataset,\n                            replace_ImageToTensor)\nfrom mmdet.models import build_detector\n\nfrom external.test import encode_mask_results, tensor2imgs\n\n\ndef single_gpu_test(model,\n                    data_loader,\n                    out_dir=None,\n                    ):\n    if out_dir is None:\n        out_dir = 'logger/blackhole'\n    print(\"The output dir is {}\".format(out_dir))\n    model.eval()\n    dataset = data_loader.dataset\n    prog_bar = mmcv.ProgressBar(len(dataset))\n\n    for i, data in enumerate(data_loader):\n        with torch.no_grad():\n            visualizations = model(return_loss=False, rescale=True, **data)\n\n        instance_map = visualizations['instance_map']\n        seg_infos = visualizations['segments_info']\n        depth = visualizations['depth_final']\n        prog_bar.update()\n    return None\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(\n        description='MMDet test (and eval) a model')\n    parser.add_argument('config', help='test config file path')\n    parser.add_argument('checkpoint', help='checkpoint file')\n    parser.add_argument('--out', help='output result file in pickle format')\n    parser.add_argument(\n        '--fuse-conv-bn',\n        action='store_true',\n        help='Whether to fuse conv and bn, this will slightly increase '\n             'the inference speed')\n    parser.add_argument(\n        '--format-only',\n        action='store_true',\n        help='Format the output results without performing evaluation. It is '\n             'useful when you want to format the result to a specific format and '\n             'submit it to the test server')\n    parser.add_argument(\n        '--eval',\n        type=str,\n        nargs='+',\n        help='evaluation metrics, which depend on the dataset, e.g., \"bbox\",'\n             ' \"segm\", \"proposal\" for COCO, and \"mAP\", \"recall\" for PASCAL VOC')\n    parser.add_argument('--show', action='store_true', help='show results')\n    parser.add_argument(\n        '--show-dir', help='directory where painted images will be saved')\n    parser.add_argument(\n        '--show-score-thr',\n        type=float,\n        default=0.3,\n        help='score threshold (default: 0.3)')\n    parser.add_argument(\n        '--gpu-collect',\n        action='store_true',\n        help='whether to use gpu to collect results.')\n    parser.add_argument(\n        '--tmpdir',\n        help='tmp directory used for collecting results from multiple '\n             'workers, available when gpu-collect is not specified')\n    parser.add_argument(\n        '--cfg-options',\n        nargs='+',\n        action=DictAction,\n        help='override some settings in the used config, the key-value pair '\n             'in xxx=yyy format will be merged into config file. If the value to '\n             'be overwritten is a list, it should be like key=\"[a,b]\" or key=a,b '\n             'It also allows nested list/tuple values, e.g. 
key=\"[(a,b),(c,d)]\" '\n             'Note that the quotation marks are necessary and that no white space '\n             'is allowed.')\n    parser.add_argument(\n        '--options',\n        nargs='+',\n        action=DictAction,\n        help='custom options for evaluation, the key-value pair in xxx=yyy '\n             'format will be kwargs for dataset.evaluate() function (deprecate), '\n             'change to --eval-options instead.')\n    parser.add_argument(\n        '--eval-options',\n        nargs='+',\n        action=DictAction,\n        help='custom options for evaluation, the key-value pair in xxx=yyy '\n             'format will be kwargs for dataset.evaluate() function')\n    parser.add_argument(\n        '--launcher',\n        choices=['none', 'pytorch', 'slurm', 'mpi'],\n        default='none',\n        help='job launcher')\n    parser.add_argument('--local_rank', type=int, default=0)\n    args = parser.parse_args()\n    if 'LOCAL_RANK' not in os.environ:\n        os.environ['LOCAL_RANK'] = str(args.local_rank)\n\n    if args.options and args.eval_options:\n        raise ValueError(\n            '--options and --eval-options cannot be both '\n            'specified, --options is deprecated in favor of --eval-options')\n    if args.options:\n        warnings.warn('--options is deprecated in favor of --eval-options')\n        args.eval_options = args.options\n    return args\n\n\ndef main():\n    args = parse_args()\n    print(args)\n\n    cfg = Config.fromfile(args.config)\n    if args.cfg_options is not None:\n        cfg.merge_from_dict(args.cfg_options)\n    # import modules from string list.\n    if cfg.get('custom_imports', None):\n        from mmcv.utils import import_modules_from_strings\n        import_modules_from_strings(**cfg['custom_imports'])\n    # set cudnn_benchmark\n    if cfg.get('cudnn_benchmark', False):\n        torch.backends.cudnn.benchmark = True\n    cfg.model.pretrained = None\n    if cfg.model.get('neck'):\n        if 
isinstance(cfg.model.neck, list):\n            for neck_cfg in cfg.model.neck:\n                if neck_cfg.get('rfp_backbone'):\n                    if neck_cfg.rfp_backbone.get('pretrained'):\n                        neck_cfg.rfp_backbone.pretrained = None\n        elif cfg.model.neck.get('rfp_backbone'):\n            if cfg.model.neck.rfp_backbone.get('pretrained'):\n                cfg.model.neck.rfp_backbone.pretrained = None\n\n    # in case the test dataset is concatenated\n    samples_per_gpu = 1\n    if isinstance(cfg.data.test, dict):\n        cfg.data.test.test_mode = True\n        samples_per_gpu = cfg.data.test.pop('samples_per_gpu', 1)\n        if samples_per_gpu > 1:\n            # Replace 'ImageToTensor' to 'DefaultFormatBundle'\n            cfg.data.test.pipeline = replace_ImageToTensor(\n                cfg.data.test.pipeline)\n    elif isinstance(cfg.data.test, list):\n        for ds_cfg in cfg.data.test:\n            ds_cfg.test_mode = True\n        samples_per_gpu = max(\n            [ds_cfg.pop('samples_per_gpu', 1) for ds_cfg in cfg.data.test])\n        if samples_per_gpu > 1:\n            for ds_cfg in cfg.data.test:\n                ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline)\n\n    # init distributed env first, since logger depends on the dist info.\n    if args.launcher == 'none':\n        distributed = False\n    else:\n        distributed = True\n        init_dist(args.launcher, **cfg.dist_params)\n\n    # build the dataloader\n    dataset = build_dataset(cfg.data.test)\n    data_loader = build_dataloader(\n        dataset,\n        samples_per_gpu=samples_per_gpu,\n        workers_per_gpu=cfg.data.workers_per_gpu,\n        dist=distributed,\n        shuffle=False)\n\n    # build the model and load checkpoint\n    cfg.model.train_cfg = None\n    model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))\n    fp16_cfg = cfg.get('fp16', None)\n    if fp16_cfg is not None:\n        wrap_fp16_model(model)\n    checkpoint = 
load_checkpoint(model, args.checkpoint, map_location='cpu', strict=True)\n    if args.fuse_conv_bn:\n        model = fuse_conv_bn(model)\n    # old versions did not save class info in checkpoints, this workaround is\n    # for backward compatibility\n    if 'CLASSES' in checkpoint.get('meta', {}):\n        model.CLASSES = checkpoint['meta']['CLASSES']\n    else:\n        model.CLASSES = dataset.CLASSES\n\n    model = MMDataParallel(model, device_ids=[0])\n    # Run inference over the sequence on GPU 0 only (no distributed path here)\n    single_gpu_test(model, data_loader, args.show_dir)\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "tools_vis/apis/__init__.py",
    "content": "from .test import single_gpu_test, multi_gpu_test"
  },
  {
    "path": "tools_vis/apis/test.py",
    "content": "# Modified from mmdet 2.20.0 / https://github.com/open-mmlab/mmdetection/tree/ff9bc\n\nimport os.path as osp\nimport pickle\nimport shutil\nimport tempfile\nimport time\n\nimport mmcv\nimport torch\nimport torch.distributed as dist\nfrom mmcv.image import tensor2imgs\nfrom mmcv.runner import get_dist_info\n\nfrom mmdet.core import encode_mask_results\n\n\ndef single_gpu_test(model,\n                    data_loader,\n                    show=False,\n                    out_dir=None,\n                    show_score_thr=0.3):\n    # NOTE: show, out_dir and show_score_thr are kept for signature\n    # compatibility with mmdet's single_gpu_test but are not used here.\n    model.eval()\n    results = []\n    dataset = data_loader.dataset\n    prog_bar = mmcv.ProgressBar(len(dataset))\n    for i, data in enumerate(data_loader):\n        with torch.no_grad():\n            result = model(return_loss=False, rescale=True, **data)\n\n        batch_size = len(result)\n\n        # encode mask results\n        for idx in range(len(result)):\n            if isinstance(result[idx][0], tuple):\n                result[idx] = [(bbox_results, encode_mask_results(mask_results))\n                               for bbox_results, mask_results in result[idx]]\n\n        results.extend(result)\n\n        for _ in range(batch_size):\n            prog_bar.update()\n\n    # each item of `results` is itself a list; flatten one level\n    results = sum(results, [])\n    return results\n\n\ndef multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):\n    \"\"\"Test model with multiple gpus.\n\n    This method tests model with multiple gpus and collects the results\n    under two different modes: gpu and cpu modes. By setting 'gpu_collect=True'\n    it encodes results to gpu tensors and uses gpu communication for results\n    collection.
In cpu mode it saves the results from different gpus to 'tmpdir'\n    and the rank 0 worker collects them.\n\n    Args:\n        model (nn.Module): Model to be tested.\n        data_loader (torch.utils.data.DataLoader): PyTorch data loader.\n        tmpdir (str): Path of directory to save the temporary results from\n            different gpus under cpu mode.\n        gpu_collect (bool): Option to use either gpu or cpu to collect results.\n\n    Returns:\n        list: The prediction results on rank 0; None on the other ranks.\n    \"\"\"\n    model.eval()\n    results = []\n    dataset = data_loader.dataset\n    rank, world_size = get_dist_info()\n    if rank == 0:\n        prog_bar = mmcv.ProgressBar(len(dataset))\n    time.sleep(2)  # This line can prevent deadlock problem in some cases.\n    for i, data in enumerate(data_loader):\n        with torch.no_grad():\n            result = model(return_loss=False, rescale=True, **data)\n            # encode mask results\n            for idx in range(len(result)):\n                if isinstance(result[idx][0], tuple):\n                    result[idx] = [(bbox_results, encode_mask_results(mask_results))\n                                   for bbox_results, mask_results in result[idx]]\n        results.extend(result)\n\n        if rank == 0:\n            batch_size = len(result)\n            for _ in range(batch_size * world_size):\n                prog_bar.update()\n\n    # collect results from all ranks\n    if gpu_collect:\n        results = collect_results_gpu(results, size=len(dataset))\n    else:\n        results = collect_results_cpu(results, size=len(dataset), tmpdir=tmpdir)\n    if rank == 0:\n        results = sum(results, [])\n    return results\n\n\ndef collect_results_cpu(result_part, size, tmpdir=None):\n    rank, world_size = get_dist_info()\n    # create a tmp dir if it is not specified\n    if tmpdir is None:\n        MAX_LEN = 512\n        # 32 is whitespace\n        dir_tensor = torch.full((MAX_LEN, ),\n                                32,\n
               dtype=torch.uint8,\n                                device='cuda')\n        if rank == 0:\n            mmcv.mkdir_or_exist('.dist_test')\n            tmpdir = tempfile.mkdtemp(dir='.dist_test')\n            tmpdir = torch.tensor(\n                bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')\n            dir_tensor[:len(tmpdir)] = tmpdir\n        dist.broadcast(dir_tensor, 0)\n        tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()\n    else:\n        mmcv.mkdir_or_exist(tmpdir)\n    # dump the part result to the dir\n    mmcv.dump(result_part, osp.join(tmpdir, f'part_{rank}.pkl'))\n    dist.barrier()\n    # collect all parts\n    if rank != 0:\n        return None\n    else:\n        # load results of all parts from tmp dir\n        part_list = []\n        for i in range(world_size):\n            part_file = osp.join(tmpdir, f'part_{i}.pkl')\n            part_list.append(mmcv.load(part_file))\n        # sort the results\n        ordered_results = []\n        for res in zip(*part_list):\n            ordered_results.extend(list(res))\n        # the dataloader may pad some samples\n        ordered_results = ordered_results[:size]\n        # remove tmp dir\n        shutil.rmtree(tmpdir)\n        return ordered_results\n\n\ndef collect_results_gpu(result_part, size):\n    rank, world_size = get_dist_info()\n    # dump result part to tensor with pickle\n    part_tensor = torch.tensor(\n        bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')\n    # gather all result part tensor shape\n    shape_tensor = torch.tensor(part_tensor.shape, device='cuda')\n    shape_list = [shape_tensor.clone() for _ in range(world_size)]\n    dist.all_gather(shape_list, shape_tensor)\n    # padding result part tensor to max length\n    shape_max = torch.tensor(shape_list).max()\n    part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')\n    part_send[:shape_tensor[0]] = part_tensor\n    part_recv_list = [\n    
    part_tensor.new_zeros(shape_max) for _ in range(world_size)\n    ]\n    # gather all result part\n    dist.all_gather(part_recv_list, part_send)\n\n    if rank == 0:\n        part_list = []\n        for recv, shape in zip(part_recv_list, shape_list):\n            part_list.append(\n                pickle.loads(recv[:shape[0]].cpu().numpy().tobytes()))\n        # sort the results\n        ordered_results = []\n        for res in zip(*part_list):\n            ordered_results.extend(list(res))\n        # the dataloader may pad some samples\n        ordered_results = ordered_results[:size]\n        return ordered_results\n"
  },
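The result-collection helpers in `tools_vis/apis/test.py` rely on two details that are easy to miss: a non-shuffled distributed sampler deals samples to ranks round-robin and pads the tail so every rank gets the same count, and `collect_results_gpu` pads each rank's pickled bytes to a common length so `all_gather` can exchange them. The sketch below reproduces both steps in plain Python, assuming a torch-style `DistributedSampler` with `shuffle=False`; the helper names are hypothetical and no torch or mmcv APIs are involved.

```python
import pickle
from typing import Any, List


def shard_like_distributed_sampler(items: List[Any], world_size: int) -> List[List[Any]]:
    """Hypothetical helper: shard items the way a non-shuffled, torch-style
    DistributedSampler does. Items are repeated from the front until the total
    is a multiple of world_size; rank r then takes every world_size-th item."""
    total = -(-len(items) // world_size) * world_size  # round up to a multiple
    padded = list(items)
    while len(padded) < total:
        padded += items[: total - len(padded)]
    return [padded[rank::world_size] for rank in range(world_size)]


def gather_via_padded_bytes(parts: List[Any]) -> List[Any]:
    """Mimic collect_results_gpu without torch: pickle each rank's part, pad
    every buffer to the longest length (all_gather needs equal shapes), then
    trim with the recorded lengths before unpickling."""
    payloads = [pickle.dumps(part) for part in parts]
    lengths = [len(buf) for buf in payloads]  # plays the role of shape_list
    max_len = max(lengths)
    padded = [buf + bytes(max_len - len(buf)) for buf in payloads]
    return [pickle.loads(buf[:n]) for buf, n in zip(padded, lengths)]


def reorder_parts(part_list: List[List[Any]], size: int) -> List[Any]:
    """Same reordering as collect_results_cpu/gpu: zip the per-rank lists back
    into dataset order, then drop the samples the sampler padded on."""
    ordered: List[Any] = []
    for res in zip(*part_list):
        ordered.extend(res)
    return ordered[:size]


if __name__ == '__main__':
    results = list(range(10))
    parts = shard_like_distributed_sampler(results, world_size=4)
    merged = reorder_parts(gather_via_padded_bytes(parts), size=len(results))
    print(parts)
    print(merged == results)
```

With 10 samples on 4 ranks, ranks 2 and 3 receive the padded samples `0` and `1`; interleaving with `zip(*parts)` restores dataset order and the `[:size]` slice drops the duplicates, which is why both collectors truncate to `len(dataset)`.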
  {
    "path": "tools_vis/dist_test_whole_video.sh",
    "content": "#!/usr/bin/env bash\n\nCONFIG=$1\nCHECKPOINT=$2\nGPUS=$3\nPORT=${PORT:-$((29500 + $RANDOM % 29))}\n\nif command -v torchrun &> /dev/null\nthen\n  echo \"Using torchrun mode.\"\n  PYTHONPATH=\"$(dirname $0)/..\":$PYTHONPATH \\\n    torchrun --nproc_per_node=$GPUS --master_port=$PORT \\\n    $(dirname \"$0\")/test_whole_video.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4}\nelse\n  echo \"Using launch mode.\"\n  PYTHONPATH=\"$(dirname $0)/..\":$PYTHONPATH \\\n    python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \\\n    $(dirname \"$0\")/test_whole_video.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4}\nfi\n"
  },
  {
    "path": "tools_vis/docker.sh",
    "content": "#!/bin/bash\n\nDATALOC=${DATALOC:-`realpath ../datasets`}\nLOGLOC=${LOGLOC:-`realpath ../logger`}\nIMG=${IMG:-\"harbory/openmmlab:eccv-2022\"}\n\ndocker run --gpus all -it --rm --ipc=host --net=host \\\n  --mount src=$(pwd),target=/data,type=bind \\\n  --mount src=$DATALOC,target=/data/data,type=bind \\\n  --mount src=$LOGLOC,target=/data/logger,type=bind \\\n  $IMG\n"
  },
  {
    "path": "tools_vis/slurm_test_vis.sh",
    "content": "#!/usr/bin/env bash\n\nset -x\n\nPARTITION=$1\nJOB_NAME=$2\nCONFIG=$3\nCHECKPOINT=$4\nGPUS=${GPUS:-8}\nGPUS_PER_NODE=${GPUS_PER_NODE:-8}\nCPUS_PER_TASK=${CPUS_PER_TASK:-5}\nPY_ARGS=${@:5}\nSRUN_ARGS=${SRUN_ARGS:-\"\"}\n\nPYTHONPATH=\"$(dirname $0)/..\":$PYTHONPATH \\\nsrun -p ${PARTITION} \\\n    --job-name=${JOB_NAME} \\\n    --gres=gpu:${GPUS_PER_NODE} \\\n    --ntasks=${GPUS} \\\n    --ntasks-per-node=${GPUS_PER_NODE} \\\n    --cpus-per-task=${CPUS_PER_TASK} \\\n    --kill-on-bad-exit=1 \\\n    ${SRUN_ARGS} \\\n    python -u tools_vis/test_whole_video.py ${CONFIG} ${CHECKPOINT} --launcher=\"slurm\" ${PY_ARGS}\n"
  },
  {
    "path": "tools_vis/test.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n# Modified from mmdet 2.20.0 / https://github.com/open-mmlab/mmdetection/tree/ff9bc\nimport argparse\nimport os\nimport os.path as osp\nimport time\nimport warnings\n\nimport mmcv\nimport torch\nfrom mmcv import Config, DictAction\nfrom mmcv.cnn import fuse_conv_bn\nfrom mmcv.parallel import MMDataParallel, MMDistributedDataParallel\nfrom mmcv.runner import (get_dist_info, init_dist, load_checkpoint,\n                         wrap_fp16_model)\n\nfrom mmdet.apis import multi_gpu_test, single_gpu_test\nfrom mmdet.datasets import (build_dataloader, build_dataset,\n                            replace_ImageToTensor)\nfrom mmdet.models import build_detector\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(\n        description='MMDet test (and eval) a model')\n    parser.add_argument('config', help='test config file path')\n    parser.add_argument('checkpoint', help='checkpoint file')\n    parser.add_argument(\n        '--work-dir',\n        help='the directory to save the file containing evaluation metrics')\n    parser.add_argument('--out', help='output result file in pickle format')\n    parser.add_argument(\n        '--fuse-conv-bn',\n        action='store_true',\n        help='Whether to fuse conv and bn, this will slightly increase '\n        'the inference speed')\n    parser.add_argument(\n        '--gpu-ids',\n        type=int,\n        nargs='+',\n        help='ids of gpus to use '\n        '(only applicable to non-distributed testing)')\n    parser.add_argument(\n        '--format-only',\n        action='store_true',\n        help='Format the output results without performing evaluation.
It is '\n        'useful when you want to format the result to a specific format and '\n        'submit it to the test server')\n    parser.add_argument(\n        '--eval',\n        type=str,\n        nargs='+',\n        help='evaluation metrics, which depend on the dataset, e.g., \"bbox\",'\n        ' \"segm\", \"proposal\" for COCO, and \"mAP\", \"recall\" for PASCAL VOC')\n    parser.add_argument('--show', action='store_true', help='show results')\n    parser.add_argument(\n        '--show-dir', help='directory where painted images will be saved')\n    parser.add_argument(\n        '--show-score-thr',\n        type=float,\n        default=0.3,\n        help='score threshold (default: 0.3)')\n    parser.add_argument(\n        '--gpu-collect',\n        action='store_true',\n        help='whether to use gpu to collect results.')\n    parser.add_argument(\n        '--tmpdir',\n        help='tmp directory used for collecting results from multiple '\n        'workers, available when gpu-collect is not specified')\n    parser.add_argument(\n        '--cfg-options',\n        nargs='+',\n        action=DictAction,\n        help='override some settings in the used config, the key-value pair '\n        'in xxx=yyy format will be merged into config file. If the value to '\n        'be overwritten is a list, it should be like key=\"[a,b]\" or key=a,b. '\n        'It also allows nested list/tuple values, e.g.
key=\"[(a,b),(c,d)]\". '\n        'Note that the quotation marks are necessary and that no white space '\n        'is allowed.')\n    parser.add_argument(\n        '--options',\n        nargs='+',\n        action=DictAction,\n        help='custom options for evaluation, the key-value pair in xxx=yyy '\n        'format will be kwargs for dataset.evaluate() function (deprecated), '\n        'change to --eval-options instead.')\n    parser.add_argument(\n        '--eval-options',\n        nargs='+',\n        action=DictAction,\n        help='custom options for evaluation, the key-value pair in xxx=yyy '\n        'format will be kwargs for dataset.evaluate() function')\n    parser.add_argument(\n        '--launcher',\n        choices=['none', 'pytorch', 'slurm', 'mpi'],\n        default='none',\n        help='job launcher')\n    parser.add_argument('--local_rank', type=int, default=0)\n    args = parser.parse_args()\n    if 'LOCAL_RANK' not in os.environ:\n        os.environ['LOCAL_RANK'] = str(args.local_rank)\n\n    if args.options and args.eval_options:\n        raise ValueError(\n            '--options and --eval-options cannot both be '\n            'specified, --options is deprecated in favor of --eval-options')\n    if args.options:\n        warnings.warn('--options is deprecated in favor of --eval-options')\n        args.eval_options = args.options\n    return args\n\n\ndef main():\n    args = parse_args()\n\n    assert args.out or args.eval or args.format_only or args.show \\\n        or args.show_dir, \\\n        ('Please specify at least one operation (save/eval/format/show the '\n         'results / save the results) with the argument \"--out\", \"--eval\"'\n         ', \"--format-only\", \"--show\" or \"--show-dir\"')\n\n    if args.eval and args.format_only:\n        raise ValueError('--eval and --format-only cannot both be specified')\n\n    if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):\n        raise ValueError('The output file
must be a pkl file.')\n\n    cfg = Config.fromfile(args.config)\n    if args.cfg_options is not None:\n        cfg.merge_from_dict(args.cfg_options)\n    # set cudnn_benchmark\n    if cfg.get('cudnn_benchmark', False):\n        torch.backends.cudnn.benchmark = True\n\n    cfg.model.pretrained = None\n    if cfg.model.get('neck'):\n        if isinstance(cfg.model.neck, list):\n            for neck_cfg in cfg.model.neck:\n                if neck_cfg.get('rfp_backbone'):\n                    if neck_cfg.rfp_backbone.get('pretrained'):\n                        neck_cfg.rfp_backbone.pretrained = None\n        elif cfg.model.neck.get('rfp_backbone'):\n            if cfg.model.neck.rfp_backbone.get('pretrained'):\n                cfg.model.neck.rfp_backbone.pretrained = None\n\n    # in case the test dataset is concatenated\n    samples_per_gpu = 1\n    if isinstance(cfg.data.test, dict):\n        cfg.data.test.test_mode = True\n        samples_per_gpu = cfg.data.test.pop('samples_per_gpu', 1)\n        if samples_per_gpu > 1:\n            # Replace 'ImageToTensor' to 'DefaultFormatBundle'\n            cfg.data.test.pipeline = replace_ImageToTensor(\n                cfg.data.test.pipeline)\n    elif isinstance(cfg.data.test, list):\n        for ds_cfg in cfg.data.test:\n            ds_cfg.test_mode = True\n        samples_per_gpu = max(\n            [ds_cfg.pop('samples_per_gpu', 1) for ds_cfg in cfg.data.test])\n        if samples_per_gpu > 1:\n            for ds_cfg in cfg.data.test:\n                ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline)\n\n    if args.gpu_ids is not None:\n        cfg.gpu_ids = args.gpu_ids\n    else:\n        cfg.gpu_ids = range(1)\n\n    # init distributed env first, since logger depends on the dist info.\n    if args.launcher == 'none':\n        distributed = False\n        if len(cfg.gpu_ids) > 1:\n            warnings.warn(\n                f'We treat {cfg.gpu_ids} as gpu-ids, and reset to '\n                f'{cfg.gpu_ids[0:1]} 
as gpu-ids to avoid potential error in '\n                'non-distributed testing.')\n            cfg.gpu_ids = cfg.gpu_ids[0:1]\n    else:\n        distributed = True\n        init_dist(args.launcher, **cfg.dist_params)\n\n    rank, _ = get_dist_info()\n    # only rank 0 creates the work dir; nothing is created without --work-dir\n    if args.work_dir is not None and rank == 0:\n        mmcv.mkdir_or_exist(osp.abspath(args.work_dir))\n        timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())\n        json_file = osp.join(args.work_dir, f'eval_{timestamp}.json')\n\n    # build the dataloader\n    dataset = build_dataset(cfg.data.test)\n    data_loader = build_dataloader(\n        dataset,\n        samples_per_gpu=samples_per_gpu,\n        workers_per_gpu=cfg.data.workers_per_gpu,\n        dist=distributed,\n        shuffle=False)\n\n    # build the model and load checkpoint\n    cfg.model.train_cfg = None\n    model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))\n    fp16_cfg = cfg.get('fp16', None)\n    if fp16_cfg is not None:\n        wrap_fp16_model(model)\n    checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')\n    if args.fuse_conv_bn:\n        model = fuse_conv_bn(model)\n    # old versions did not save class info in checkpoints, this workaround is\n    # for backward compatibility\n    if 'CLASSES' in checkpoint.get('meta', {}):\n        model.CLASSES = checkpoint['meta']['CLASSES']\n    else:\n        model.CLASSES = dataset.CLASSES\n\n    if not distributed:\n        model = MMDataParallel(model, device_ids=cfg.gpu_ids)\n        outputs = single_gpu_test(model, data_loader, args.show, args.show_dir,\n                                  args.show_score_thr)\n    else:\n        model = MMDistributedDataParallel(\n            model.cuda(),\n            device_ids=[torch.cuda.current_device()],\n            broadcast_buffers=False)\n        outputs = multi_gpu_test(model, data_loader, args.tmpdir,\n                                 
args.gpu_collect)\n\n    rank, _
= get_dist_info()\n    if rank == 0:\n        if args.out:\n            print(f'\\nwriting results to {args.out}')\n            mmcv.dump(outputs, args.out)\n        kwargs = {} if args.eval_options is None else args.eval_options\n        if args.format_only:\n            dataset.format_results(outputs, **kwargs)\n        if args.eval:\n            eval_kwargs = cfg.get('evaluation', {}).copy()\n            # hard-code way to remove EvalHook args\n            for key in [\n                    'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best',\n                    'rule', 'dynamic_intervals'\n            ]:\n                eval_kwargs.pop(key, None)\n            eval_kwargs.update(dict(metric=args.eval, **kwargs))\n            metric = dataset.evaluate(outputs, **eval_kwargs)\n            print(metric)\n            metric_dict = dict(config=args.config, metric=metric)\n            if args.work_dir is not None and rank == 0:\n                mmcv.dump(metric_dict, json_file)\n\n\nif __name__ == '__main__':\n    main()\n"
  },
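`tools_vis/test.py` reuses the training-time `evaluation` config but strips the keys that only configure mmdet's `EvalHook` before calling `dataset.evaluate()`, then lets `--eval` and `--eval-options` override what remains. A minimal sketch of that merge, assuming plain dicts in place of mmcv `Config` objects; `build_eval_kwargs` is a hypothetical helper name, and the key list is copied from the script.

```python
from typing import Any, Dict, List, Optional

# Keys that configure the training-time EvalHook rather than the metric
# computation; test.py drops them with its "hard-code way" loop.
EVAL_HOOK_KEYS = ('interval', 'tmpdir', 'start', 'gpu_collect', 'save_best',
                  'rule', 'dynamic_intervals')


def build_eval_kwargs(evaluation_cfg: Optional[Dict[str, Any]],
                      metrics: List[str],
                      eval_options: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Hypothetical helper mirroring test.py: copy cfg.evaluation, drop the
    hook-only keys, then let --eval and --eval-options take precedence."""
    eval_kwargs = dict(evaluation_cfg or {})  # copy so the config is untouched
    for key in EVAL_HOOK_KEYS:
        eval_kwargs.pop(key, None)
    # As in test.py, a 'metric' key inside eval_options clashes with
    # metric=... and dict() raises TypeError.
    eval_kwargs.update(dict(metric=metrics, **(eval_options or {})))
    return eval_kwargs


if __name__ == '__main__':
    cfg = {'interval': 1, 'metric': 'bbox', 'save_best': 'auto', 'classwise': True}
    print(build_eval_kwargs(cfg, ['segm'], {'jsonfile_prefix': 'out'}))
```

One consequence worth knowing: passing `metric=...` through `--eval-options` collides with `--eval` and raises `TypeError`, both in the script and in this sketch.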
  {
    "path": "tools_vis/test_whole_video.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n# Modified from mmdet 2.20.0 / https://github.com/open-mmlab/mmdetection/tree/ff9bc\nimport argparse\nimport os\nimport os.path as osp\nimport time\nimport warnings\n\nimport mmcv\nimport torch\nfrom mmcv import Config, DictAction\nfrom mmcv.cnn import fuse_conv_bn\nfrom mmcv.parallel import MMDataParallel, MMDistributedDataParallel\nfrom mmcv.runner import (get_dist_info, init_dist, load_checkpoint,\n                         wrap_fp16_model)\n\nfrom tools_vis.apis import multi_gpu_test, single_gpu_test\nfrom mmdet.datasets import (build_dataloader, build_dataset,\n                            replace_ImageToTensor)\nfrom mmdet.models import build_detector\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(\n        description='MMDet test (and eval) a model')\n    parser.add_argument('config', help='test config file path')\n    parser.add_argument('checkpoint', help='checkpoint file')\n    parser.add_argument(\n        '--work-dir',\n        help='the directory to save the file containing evaluation metrics')\n    parser.add_argument('--out', help='output result file in pickle format')\n    parser.add_argument(\n        '--fuse-conv-bn',\n        action='store_true',\n        help='Whether to fuse conv and bn, this will slightly increase '\n        'the inference speed')\n    parser.add_argument(\n        '--gpu-ids',\n        type=int,\n        nargs='+',\n        help='ids of gpus to use '\n        '(only applicable to non-distributed testing)')\n    parser.add_argument(\n        '--format-only',\n        action='store_true',\n        help='Format the output results without performing evaluation.
It is '\n        'useful when you want to format the result to a specific format and '\n        'submit it to the test server')\n    parser.add_argument(\n        '--eval',\n        type=str,\n        nargs='+',\n        help='evaluation metrics, which depend on the dataset, e.g., \"bbox\",'\n        ' \"segm\", \"proposal\" for COCO, and \"mAP\", \"recall\" for PASCAL VOC')\n    parser.add_argument('--show', action='store_true', help='show results')\n    parser.add_argument(\n        '--show-dir', help='directory where painted images will be saved')\n    parser.add_argument(\n        '--show-score-thr',\n        type=float,\n        default=0.3,\n        help='score threshold (default: 0.3)')\n    parser.add_argument(\n        '--gpu-collect',\n        action='store_true',\n        help='whether to use gpu to collect results.')\n    parser.add_argument(\n        '--tmpdir',\n        help='tmp directory used for collecting results from multiple '\n        'workers, available when gpu-collect is not specified')\n    parser.add_argument(\n        '--cfg-options',\n        nargs='+',\n        action=DictAction,\n        help='override some settings in the used config, the key-value pair '\n        'in xxx=yyy format will be merged into config file. If the value to '\n        'be overwritten is a list, it should be like key=\"[a,b]\" or key=a,b. '\n        'It also allows nested list/tuple values, e.g.
key=\"[(a,b),(c,d)]\". '\n        'Note that the quotation marks are necessary and that no white space '\n        'is allowed.')\n    parser.add_argument(\n        '--options',\n        nargs='+',\n        action=DictAction,\n        help='custom options for evaluation, the key-value pair in xxx=yyy '\n        'format will be kwargs for dataset.evaluate() function (deprecated), '\n        'change to --eval-options instead.')\n    parser.add_argument(\n        '--eval-options',\n        nargs='+',\n        action=DictAction,\n        help='custom options for evaluation, the key-value pair in xxx=yyy '\n        'format will be kwargs for dataset.evaluate() function')\n    parser.add_argument(\n        '--launcher',\n        choices=['none', 'pytorch', 'slurm', 'mpi'],\n        default='none',\n        help='job launcher')\n    parser.add_argument('--local_rank', type=int, default=0)\n    args = parser.parse_args()\n    if 'LOCAL_RANK' not in os.environ:\n        os.environ['LOCAL_RANK'] = str(args.local_rank)\n\n    if args.options and args.eval_options:\n        raise ValueError(\n            '--options and --eval-options cannot both be '\n            'specified, --options is deprecated in favor of --eval-options')\n    if args.options:\n        warnings.warn('--options is deprecated in favor of --eval-options')\n        args.eval_options = args.options\n    return args\n\n\ndef main():\n    args = parse_args()\n\n    assert args.out or args.eval or args.format_only or args.show \\\n        or args.show_dir, \\\n        ('Please specify at least one operation (save/eval/format/show the '\n         'results / save the results) with the argument \"--out\", \"--eval\"'\n         ', \"--format-only\", \"--show\" or \"--show-dir\"')\n\n    if args.eval and args.format_only:\n        raise ValueError('--eval and --format-only cannot both be specified')\n\n    # if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):\n    #     raise ValueError('The output
file must be a pkl file.')\n\n    cfg = Config.fromfile(args.config)\n    if args.cfg_options is not None:\n        cfg.merge_from_dict(args.cfg_options)\n    # set cudnn_benchmark\n    if cfg.get('cudnn_benchmark', False):\n        torch.backends.cudnn.benchmark = True\n\n    cfg.model.pretrained = None\n    if cfg.model.get('neck'):\n        if isinstance(cfg.model.neck, list):\n            for neck_cfg in cfg.model.neck:\n                if neck_cfg.get('rfp_backbone'):\n                    if neck_cfg.rfp_backbone.get('pretrained'):\n                        neck_cfg.rfp_backbone.pretrained = None\n        elif cfg.model.neck.get('rfp_backbone'):\n            if cfg.model.neck.rfp_backbone.get('pretrained'):\n                cfg.model.neck.rfp_backbone.pretrained = None\n\n    # in case the test dataset is concatenated\n    samples_per_gpu = 1\n    if isinstance(cfg.data.test, dict):\n        cfg.data.test.test_mode = True\n        samples_per_gpu = cfg.data.test.pop('samples_per_gpu', 1)\n        if samples_per_gpu > 1:\n            # Replace 'ImageToTensor' to 'DefaultFormatBundle'\n            cfg.data.test.pipeline = replace_ImageToTensor(\n                cfg.data.test.pipeline)\n    elif isinstance(cfg.data.test, list):\n        for ds_cfg in cfg.data.test:\n            ds_cfg.test_mode = True\n        samples_per_gpu = max(\n            [ds_cfg.pop('samples_per_gpu', 1) for ds_cfg in cfg.data.test])\n        if samples_per_gpu > 1:\n            for ds_cfg in cfg.data.test:\n                ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline)\n\n    if args.gpu_ids is not None:\n        cfg.gpu_ids = args.gpu_ids\n    else:\n        cfg.gpu_ids = range(1)\n\n    # init distributed env first, since logger depends on the dist info.\n    if args.launcher == 'none':\n        distributed = False\n        if len(cfg.gpu_ids) > 1:\n            warnings.warn(\n                f'We treat {cfg.gpu_ids} as gpu-ids, and reset to '\n                
f'{cfg.gpu_ids[0:1]} as gpu-ids to avoid potential error in '\n                'non-distributed testing.')\n            cfg.gpu_ids = cfg.gpu_ids[0:1]\n    else:\n        distributed = True\n        init_dist(args.launcher, **cfg.dist_params)\n\n    rank, _ = get_dist_info()\n    # only rank 0 creates the work dir; nothing is created without --work-dir\n    if args.work_dir is not None and rank == 0:\n        mmcv.mkdir_or_exist(osp.abspath(args.work_dir))\n        timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())\n        json_file = osp.join(args.work_dir, f'eval_{timestamp}.json')\n\n    # build the dataloader\n    dataset = build_dataset(cfg.data.test)\n    data_loader = build_dataloader(\n        dataset,\n        samples_per_gpu=samples_per_gpu,\n        workers_per_gpu=cfg.data.workers_per_gpu,\n        dist=distributed,\n        shuffle=False)\n\n    # build the model and load checkpoint\n    cfg.model.train_cfg = None\n    model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))\n    fp16_cfg = cfg.get('fp16', None)\n    if fp16_cfg is not None:\n        wrap_fp16_model(model)\n    checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')\n    if args.fuse_conv_bn:\n        model = fuse_conv_bn(model)\n    # old versions did not save class info in checkpoints, this workaround is\n    # for backward compatibility\n    if 'CLASSES' in checkpoint.get('meta', {}):\n        model.CLASSES = checkpoint['meta']['CLASSES']\n    else:\n        model.CLASSES = dataset.CLASSES\n\n    if not distributed:\n        model = MMDataParallel(model, device_ids=cfg.gpu_ids)\n        outputs = single_gpu_test(model, data_loader, args.show, args.show_dir,\n                                  args.show_score_thr)\n    else:\n        model = MMDistributedDataParallel(\n            model.cuda(),\n            device_ids=[torch.cuda.current_device()],\n            broadcast_buffers=False)\n        outputs = multi_gpu_test(model, data_loader, args.tmpdir,\n                                 
args.gpu_collect)\n\n    rank, _ = get_dist_info()\n    if rank == 0:\n        if args.out:\n            print(f'\\nwriting results to {args.out}')\n            mmcv.dump(outputs, args.out)\n        # copy so that adding resfile_path does not mutate args.eval_options\n        kwargs = {} if args.eval_options is None else dict(args.eval_options)\n        kwargs['resfile_path'] = args.checkpoint.replace('.pth', '_results')\n        if kwargs['resfile_path'].startswith('logger/'):\n            # point logger/submission.zip at this checkpoint's zip; the link\n            # target is relative to logger/, hence the leading '../'\n            os.system(\"ln -sf {} {}\".format(\n                os.path.join('../', kwargs['resfile_path'], 'submission_file.zip'), 'logger/submission.zip'))\n        if args.format_only:\n            dataset.format_results(outputs, **kwargs)\n        if args.eval:\n            eval_kwargs = cfg.get('evaluation', {}).copy()\n            # hard-code way to remove EvalHook args\n            for key in [\n                    'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best',\n                    'rule', 'dynamic_intervals'\n            ]:\n                eval_kwargs.pop(key, None)\n            eval_kwargs.update(dict(metric=args.eval, **kwargs))\n            metric = dataset.evaluate(outputs, **eval_kwargs)\n            print(metric)\n            metric_dict = dict(config=args.config, metric=metric)\n            if args.work_dir is not None and rank == 0:\n                mmcv.dump(metric_dict, json_file)\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "unitrack/__init__.py",
    "content": "from .model import *"
  },
  {
    "path": "unitrack/basetrack.py",
    "content": "import numpy as np\nfrom collections import OrderedDict,deque\nfrom unitrack.core.motion.kalman_filter import KalmanFilter\nimport unitrack.core.association.matching as matching\nfrom unitrack.utils.box import *\nimport torch\nimport torch.nn.functional as F\n\n\nclass TrackState(object):\n    New = 0\n    Tracked = 1\n    Lost = 2\n    Removed = 3\n\n\nclass BaseTrack(object):\n    _count = 0\n\n    track_id = 0\n    is_activated = False\n    state = TrackState.New\n\n    history = OrderedDict()\n    features = []\n    curr_feature = None\n    score = 0\n    start_frame = 0\n    frame_id = 0\n    time_since_update = 0\n\n    # multi-camera\n    location = (np.inf, np.inf)\n\n    @property\n    def end_frame(self):\n        return self.frame_id\n\n    @staticmethod\n    def next_id():\n        BaseTrack._count += 1\n        return BaseTrack._count\n\n    def activate(self, *args):\n        raise NotImplementedError\n\n    def predict(self):\n        raise NotImplementedError\n\n    def update(self, *args, **kwargs):\n        raise NotImplementedError\n\n    def mark_lost(self):\n        self.state = TrackState.Lost\n\n    def mark_removed(self):\n        self.state = TrackState.Removed\n\n\nclass STrack(BaseTrack):\n    shared_kalman = KalmanFilter()\n\n    def __init__(self, tlwh, score, temp_feat, buffer_size=30,\n            mask=None, pose=None, ac=False, category=-1, use_kalman=True):\n\n        # initial state, before activate() is called\n        self._tlwh = np.asarray(tlwh, dtype=np.float64)\n        self.kalman_filter = None\n        self.mean, self.covariance = None, None\n        self.use_kalman = use_kalman\n        if not use_kalman:\n            # without the Kalman filter, new tracks are activated immediately\n            ac = True\n        self.is_activated = ac\n\n        self.score = score\n        self.category = category\n        self.tracklet_len = 0\n\n        self.smooth_feat = None\n        self.update_features(temp_feat)\n        self.features = deque([], maxlen=buffer_size)\n        self.alpha = 0.9\n        self.mask = mask\n        
self.pose = pose\n\n    def update_features(self, feat):\n        \"\"\"Store the current feature and update its exponential moving average.\"\"\"\n        self.curr_feat = feat\n        if self.smooth_feat is None:\n            self.smooth_feat = feat\n        elif self.smooth_feat.shape == feat.shape:\n            self.smooth_feat = self.alpha * self.smooth_feat + (1 - self.alpha) * feat\n        else:\n            # shape mismatch (e.g. a different mask size): keep the old smoothed feature\n            pass\n\n\n    def predict(self):\n        mean_state = self.mean.copy()\n        if self.state != TrackState.Tracked:\n            # zero the height velocity (vh) for tracks that are not currently tracked\n            mean_state[7] = 0\n        self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)\n\n    @staticmethod\n    def multi_predict(stracks):\n        if len(stracks) > 0:\n            multi_mean = np.asarray([st.mean.copy() for st in stracks])\n            multi_covariance = np.asarray([st.covariance for st in stracks])\n            for i,st in enumerate(stracks):\n                if st.state != TrackState.Tracked:\n                    multi_mean[i][7] = 0\n            multi_mean, multi_covariance = STrack.shared_kalman.multi_predict(multi_mean, multi_covariance)\n            for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):\n                stracks[i].mean = mean\n                stracks[i].covariance = cov\n\n\n    def activate(self, kalman_filter, frame_id):\n        \"\"\"Start a new tracklet\"\"\"\n        self.kalman_filter = kalman_filter\n        self.track_id = self.next_id()\n        self.mean, self.covariance = self.kalman_filter.initiate(tlwh_to_xyah(self._tlwh))\n\n        self.tracklet_len = 0\n        self.state = TrackState.Tracked\n        # tracks born on the first frame are activated immediately; otherwise\n        # is_activated keeps its constructor value until update() or re_activate()\n        if frame_id == 1:\n            self.is_activated = True\n        self.frame_id = frame_id\n        self.start_frame = frame_id\n\n    def re_activate(self, new_track, frame_id, new_id=False, update_feature=True):\n        if self.use_kalman:\n            self.mean, self.covariance = self.kalman_filter.update(\n                self.mean, self.covariance, 
tlwh_to_xyah(new_track.tlwh)\n            )\n        else:\n            self.mean, self.covariance = None, None\n            self._tlwh = np.asarray(new_track.tlwh, dtype=np.float64)\n        if update_feature:\n            self.update_features(new_track.curr_feat)\n        self.tracklet_len = 0\n        self.state = TrackState.Tracked\n        self.is_activated = True\n        self.frame_id = frame_id\n        if new_id:\n            self.track_id = self.next_id()\n        if new_track.mask is not None:\n            self.mask = new_track.mask\n\n    def update(self, new_track, frame_id, update_feature=True):\n        \"\"\"\n        Update a matched track\n        :type new_track: STrack\n        :type frame_id: int\n        :type update_feature: bool\n        :return:\n        \"\"\"\n        self.frame_id = frame_id\n        self.tracklet_len += 1\n\n        new_tlwh = new_track.tlwh\n        if self.use_kalman:\n            self.mean, self.covariance = self.kalman_filter.update(\n                self.mean, self.covariance, tlwh_to_xyah(new_tlwh))\n        else:\n            self.mean, self.covariance = None, None\n            self._tlwh = np.asarray(new_tlwh, dtype=np.float64)\n        self.state = TrackState.Tracked\n        self.is_activated = True\n\n        self.score = new_track.score\n        # keep the category up to date (used for the TAO dataset)\n        self.category = new_track.category\n        if update_feature:\n            self.update_features(new_track.curr_feat)\n        if new_track.mask is not None:\n            self.mask = new_track.mask\n        if new_track.pose is not None:\n            self.pose = new_track.pose\n\n    @property\n    def tlwh(self):\n        \"\"\"Get current position in bounding box format `(top left x, top left y,\n                width, height)`.\n        \"\"\"\n        if self.mean is None:\n            return self._tlwh.copy()\n        ret = self.mean[:4].copy()\n        ret[2] *= ret[3]\n        ret[:2] -= ret[2:] / 2\n        
return ret\n\n    @property\n    def tlbr(self):\n        \"\"\"Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,\n        `(top left, bottom right)`.\n        \"\"\"\n        ret = self.tlwh.copy()\n        ret[2:] += ret[:2]\n        return ret\n\n\n    def to_xyah(self):\n        return tlwh_to_xyah(self.tlwh)\n\n\n    def __repr__(self):\n        return 'OT_{}_({}-{})'.format(self.track_id, self.start_frame, self.end_frame)\n\n\ndef joint_stracks(tlista, tlistb):\n    \"\"\"Union of two track lists by track_id, keeping the order of first occurrence.\"\"\"\n    exists = {}\n    res = []\n    for t in tlista:\n        exists[t.track_id] = 1\n        res.append(t)\n    for t in tlistb:\n        tid = t.track_id\n        if not exists.get(tid, 0):\n            exists[tid] = 1\n            res.append(t)\n    return res\n\n\ndef sub_stracks(tlista, tlistb):\n    \"\"\"Tracks in tlista whose track_id does not appear in tlistb.\"\"\"\n    stracks = {}\n    for t in tlista:\n        stracks[t.track_id] = t\n    for t in tlistb:\n        tid = t.track_id\n        if stracks.get(tid, 0):\n            del stracks[tid]\n    return list(stracks.values())\n\n\ndef remove_duplicate_stracks(stracksa, stracksb, ioudist=0.15):\n    \"\"\"For each pair with IoU distance < ioudist, drop the shorter-lived track.\"\"\"\n    pdist = matching.iou_distance(stracksa, stracksb)\n    pairs = np.where(pdist<ioudist)\n    dupa, dupb = list(), list()\n    for p,q in zip(*pairs):\n        timep = stracksa[p].frame_id - stracksa[p].start_frame\n        timeq = stracksb[q].frame_id - stracksb[q].start_frame\n        if timep > timeq:\n            dupb.append(q)\n        else:\n            dupa.append(p)\n    resa = [t for i,t in enumerate(stracksa) if i not in dupa]\n    resb = [t for i,t in enumerate(stracksb) if i not in dupb]\n    return resa, resb\n"
  },
  {
    "path": "unitrack/box.py",
    "content": "###################################################################\n# File Name: box.py\n# Author: Zhongdao Wang\n# mail: wcd17@mails.tsinghua.edu.cn\n# Created Time: Fri Jan 29 15:16:53 2021\n###################################################################\n\nimport torch\nfrom torchvision import ops\n\nfrom .basetrack import STrack\nfrom .multitracker import AssociationTracker\nfrom unitrack.utils.box import scale_box, scale_box_input_size, xywh2xyxy, tlbr_to_tlwh\n\n\nclass BoxAssociationTracker(AssociationTracker):\n    def __init__(self, opt):\n        super(BoxAssociationTracker, self).__init__(opt)\n\n    def extract_emb(self, img, obs):\n        feat = self.app_model(img.unsqueeze(0).to(self.opt.device).float())\n        scale = [feat.shape[-1]/self.opt.img_size[0],\n                 feat.shape[-2]/self.opt.img_size[1]]\n        obs_feat = scale_box(scale, obs).to(self.opt.device)\n        obs_feat = [obs_feat[:, :4], ]\n        ret = ops.roi_align(feat, obs_feat, self.opt.feat_size).detach().cpu()\n        return ret\n\n    def prepare_obs(self, img, img0, obs, embs=None):\n        obs = torch.from_numpy(obs[obs[:, 4] > self.opt.conf_thres]).float()\n        if len(obs) > 0:\n            obs = xywh2xyxy(obs)\n            obs = scale_box(self.opt.img_size, obs)\n            embs = self.extract_emb(img, obs)\n            obs = scale_box_input_size(self.opt.img_size, obs, img0.shape)\n\n            if obs.shape[1] == 5:\n                detections = [STrack(tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f,\n                              self.buffer_size, use_kalman=self.opt.use_kalman)\n                              for (tlbrs, f) in zip(obs, embs)]\n            elif obs.shape[1] == 6:\n                detections = [STrack(tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f,\n                              self.buffer_size, category=tlbrs[5],\n                              use_kalman=self.opt.use_kalman)\n                              for (tlbrs, f) in zip(obs, 
embs)]\n            else:\n                raise ValueError(\n                        'Shape of observations should be [n, 5] or [n, 6].')\n        else:\n            detections = []\n        return detections\n"
  },
  {
    "path": "unitrack/core/__init__.py",
    "content": ""
  },
  {
    "path": "unitrack/core/association/__init__.py",
    "content": ""
  },
  {
    "path": "unitrack/core/association/matching.py",
    "content": "import torch\nimport torch.nn.functional as F\nimport numpy as np\nimport scipy.sparse\nfrom scipy.spatial.distance import cdist\nimport lap\n\nfrom cython_bbox import bbox_overlaps as bbox_ious\nfrom ..motion import kalman_filter\n\n\ndef merge_matches(m1, m2, shape):\n    O,P,Q = shape\n    m1 = np.asarray(m1)\n    m2 = np.asarray(m2)\n\n    M1 = scipy.sparse.coo_matrix((np.ones(len(m1)), (m1[:, 0], m1[:, 1])), shape=(O, P))\n    M2 = scipy.sparse.coo_matrix((np.ones(len(m2)), (m2[:, 0], m2[:, 1])), shape=(P, Q))\n\n    mask = M1*M2\n    match = mask.nonzero()\n    match = list(zip(match[0], match[1]))\n    unmatched_O = tuple(set(range(O)) - set([i for i, j in match]))\n    unmatched_Q = tuple(set(range(Q)) - set([j for i, j in match]))\n\n    return match, unmatched_O, unmatched_Q\n\n\ndef linear_assignment(cost_matrix, thresh):\n    if cost_matrix.size == 0:\n        return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1]))\n    matches, unmatched_a, unmatched_b = [], [], []\n    cost, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)\n    for ix, mx in enumerate(x):\n        if mx >= 0:\n            matches.append([ix, mx])\n    unmatched_a = np.where(x < 0)[0]\n    unmatched_b = np.where(y < 0)[0]\n    matches = np.asarray(matches)\n    return matches, unmatched_a, unmatched_b\n\n\ndef ious(atlbrs, btlbrs):\n    \"\"\"\n    Compute pairwise IoU between two sets of boxes\n    :type atlbrs: list[tlbr] | np.ndarray\n    :type btlbrs: list[tlbr] | np.ndarray\n\n    :rtype ious np.ndarray\n    \"\"\"\n    ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float64)\n    if ious.size == 0:\n        return ious\n\n    ious = bbox_ious(\n        np.ascontiguousarray(atlbrs, dtype=np.float64),\n        np.ascontiguousarray(btlbrs, dtype=np.float64)\n    )\n\n    return ious\n\n\ndef iou_distance(atracks, btracks):\n    \"\"\"\n    Compute cost based on IoU\n    :type atracks: list[STrack]\n    :type 
btracks: list[STrack]\n\n    :rtype cost_matrix np.ndarray\n    \"\"\"\n\n    if (len(atracks)>0 and isinstance(atracks[0], np.ndarray)) or (len(btracks) > 0 and isinstance(btracks[0], np.ndarray)):\n        atlbrs = atracks\n        btlbrs = btracks\n    else:\n        atlbrs = [track.tlbr for track in atracks]\n        btlbrs = [track.tlbr for track in btracks]\n    _ious = ious(atlbrs, btlbrs)\n    cost_matrix = 1 - _ious\n\n    return cost_matrix\n\ndef embedding_distance(tracks, detections, metric='cosine'):\n    \"\"\"\n    :param tracks: list[STrack]\n    :param detections: list[BaseTrack]\n    :param metric: unused; Euclidean distance is always computed\n    :return: cost_matrix np.ndarray\n    \"\"\"\n\n    cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float64)\n    if cost_matrix.size == 0:\n        return cost_matrix\n    det_features = np.asarray([track.curr_feat for track in detections], dtype=np.float64)\n    track_features = np.asarray([track.smooth_feat for track in tracks], dtype=np.float64)\n    cost_matrix = np.maximum(0.0, cdist(track_features, det_features)) # Euclidean distance between (assumed L2-normalized) features\n    return cost_matrix\n\n\ndef fuse_motion(kf, cost_matrix, tracks, detections, only_position=False, lambda_=0.98, gate=True):\n    if cost_matrix.size == 0:\n        return cost_matrix\n    gating_dim = 2 if only_position else 4\n    gating_threshold = kalman_filter.chi2inv95[gating_dim]\n    measurements = np.asarray([det.to_xyah() for det in detections])\n    for row, track in enumerate(tracks):\n        gating_distance = kf.gating_distance(\n            track.mean, track.covariance, measurements, only_position, metric='maha')\n        if gate:\n            cost_matrix[row, gating_distance > gating_threshold] = np.inf\n        cost_matrix[row] = lambda_ * cost_matrix[row] + (1-lambda_)* gating_distance\n    return cost_matrix\n\n\ndef center_emb_distance(tracks, detections, metric='cosine'):\n    \"\"\"\n    :param tracks: list[STrack]\n    :param detections: list[BaseTrack]\n    :param metric: unused; cosine distance is always computed\n   
 :return: cost_matrix np.ndarray\n    \"\"\"\n\n    cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float64)\n    if cost_matrix.size == 0:\n        return cost_matrix\n    det_features = torch.stack([track.curr_feat.squeeze() for track in detections])\n    track_features = torch.stack([track.smooth_feat.squeeze() for track in tracks])\n    normed_det = F.normalize(det_features)\n    normed_track = F.normalize(track_features)\n    cost_matrix = torch.mm(normed_track, normed_det.T)\n    cost_matrix = 1 - cost_matrix.detach().cpu().numpy()\n    return cost_matrix\n\ndef recons_distance(tracks, detections, tmp=100):\n    \"\"\"\n    :param tracks: list[STrack]\n    :param detections: list[BaseTrack]\n    :param tmp: scale applied to the affinities before the softmax\n    :return: cost_matrix np.ndarray\n    \"\"\"\n\n    cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float64)\n    if cost_matrix.size == 0:\n        return cost_matrix\n    det_features_ = torch.stack([track.curr_feat.squeeze() for track in detections])\n    track_features_ = torch.stack([track.smooth_feat for track in tracks])\n    det_features = F.normalize(det_features_, dim=1)\n    track_features = F.normalize(track_features_, dim=1)\n\n    ndet, ndim, nw, nh = det_features.shape\n    ntrk, _, _, _ = track_features.shape\n    fdet = det_features.permute(0,2,3,1).reshape(-1, ndim).cuda()        # ndet*nw*nh, ndim\n    ftrk = track_features.permute(0,2,3,1).reshape(-1, ndim).cuda()      # ntrk*nw*nh, ndim\n\n    aff = torch.mm(ftrk, fdet.transpose(0,1))                             # ntrk*nw*nh, ndet*nw*nh\n    aff_td = F.softmax(tmp*aff, dim=1)\n    aff_dt = F.softmax(tmp*aff, dim=0).transpose(0,1)\n\n    recons_ftrk = torch.einsum('tds,dsm->tdm', aff_td.view(ntrk*nw*nh, ndet, nw*nh),\n                                fdet.view(ndet, nw*nh, ndim))         # ntrk*nw*nh, ndet, ndim\n    recons_fdet = torch.einsum('dts,tsm->dtm', aff_dt.view(ndet*nw*nh, ntrk, nw*nh),\n                                ftrk.view(ntrk, nw*nh, 
ndim))         # ndet*nw*nh, ntrk, ndim\n\n    res_ftrk = (recons_ftrk.permute(0,2,1) - ftrk.unsqueeze(-1)).view(ntrk, nw*nh*ndim, ndet)\n    res_fdet = (recons_fdet.permute(0,2,1) - fdet.unsqueeze(-1)).view(ndet, nw*nh*ndim, ntrk)\n\n    cost_matrix = (torch.abs(res_ftrk).mean(1) + torch.abs(res_fdet).mean(1).transpose(0,1)) * 0.5\n    cost_matrix = cost_matrix / cost_matrix.max(1)[0].unsqueeze(-1)\n    cost_matrix = cost_matrix.cpu().numpy()\n    return cost_matrix\n\n\ndef get_track_feat(tracks, feat_flag='curr'):\n    \"\"\"Stack per-track features into a zero-padded (n, fdim, max_numel) tensor.\"\"\"\n    if feat_flag == 'curr':\n        feat_list = [track.curr_feat.squeeze(0) for track in tracks]\n    elif feat_flag == 'smooth':\n        feat_list = [track.smooth_feat.squeeze(0) for track in tracks]\n    else:\n        raise NotImplementedError\n\n    n = len(tracks)\n    fdim = feat_list[0].shape[0]\n    fdim_num = len(feat_list[0].shape)\n    if fdim_num > 2:\n        feat_list = [f.view(fdim,-1) for f in feat_list]\n    numels = [f.shape[1] for f in feat_list]\n\n    ret = torch.zeros(n, fdim, np.max(numels)).to(feat_list[0].device)\n    for i, f in enumerate(feat_list):\n        ret[i, :, :numels[i]] = f\n    return ret\n\ndef reconsdot_distance(tracks, detections, tmp=100):\n    \"\"\"\n    :param tracks: list[STrack]\n    :param detections: list[BaseTrack]\n    :param tmp: scale applied to the affinities before the softmax\n    :return: (cost_matrix np.ndarray, None)\n    \"\"\"\n    cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float64)\n    if cost_matrix.size == 0:\n        return cost_matrix, None\n    det_features_ = get_track_feat(detections)\n    track_features_ = get_track_feat(tracks, feat_flag='curr')\n\n    det_features = F.normalize(det_features_, dim=1)\n    track_features = F.normalize(track_features_, dim=1)\n\n    ndet, ndim, nsd = det_features.shape\n    ntrk, _, nst = track_features.shape\n\n    fdet = det_features.permute(0, 2, 1).reshape(-1, ndim)\n    ftrk = track_features.permute(0, 2, 1).reshape(-1, ndim)\n\n    aff = 
torch.mm(ftrk, fdet.transpose(0, 1))\n    aff_td = F.softmax(tmp*aff, dim=1)\n    aff_dt = F.softmax(tmp*aff, dim=0).transpose(0, 1)\n\n    recons_ftrk = torch.einsum('tds,dsm->tdm', aff_td.view(ntrk*nst, ndet, nsd),\n                               fdet.view(ndet, nsd, ndim))\n    recons_fdet = torch.einsum('dts,tsm->dtm', aff_dt.view(ndet*nsd, ntrk, nst),\n                               ftrk.view(ntrk, nst, ndim))\n\n    recons_ftrk = recons_ftrk.permute(0, 2, 1).reshape((ntrk, nst*ndim, ndet))\n    recons_ftrk_norm = F.normalize(recons_ftrk, dim=1)\n    # reshape (not view): the permuted tensor is generally non-contiguous\n    recons_fdet = recons_fdet.permute(0, 2, 1).reshape((ndet, nsd*ndim, ntrk))\n    recons_fdet_norm = F.normalize(recons_fdet, dim=1)\n\n    dot_td = torch.einsum('tad,ta->td', recons_ftrk_norm,\n                          F.normalize(ftrk.reshape(ntrk, nst*ndim), dim=1))\n    dot_dt = torch.einsum('dat,da->dt', recons_fdet_norm,\n                          F.normalize(fdet.reshape(ndet, nsd*ndim), dim=1))\n\n    cost_matrix = 1 - 0.5 * (dot_td + dot_dt.transpose(0, 1))\n    cost_matrix = cost_matrix.detach().cpu().numpy()\n\n    return cost_matrix, None\n\n\ndef category_gate(cost_matrix, tracks, detections):\n    \"\"\"\n    Penalize track/detection pairs by the absolute difference of their category ids\n    :param cost_matrix: np.ndarray\n    :param tracks: list[STrack]\n    :param detections: list[BaseTrack]\n    :return: cost_matrix np.ndarray\n    \"\"\"\n    if cost_matrix.size == 0:\n        return cost_matrix\n\n    det_categories = np.array([d.category for d in detections])\n    trk_categories = np.array([t.category for t in tracks])\n\n    cost_matrix = cost_matrix + np.abs(\n            det_categories[None, :] - trk_categories[:, None])\n    return cost_matrix\n\n\n"
  },
  {
    "path": "unitrack/core/motion/kalman_filter.py",
    "content": "# vim: expandtab:ts=4:sw=4\nimport numpy as np\nimport scipy.linalg\n\n\n\"\"\"\nTable for the 0.95 quantile of the chi-square distribution with N degrees of\nfreedom (contains values for N=1, ..., 9). Taken from MATLAB/Octave's chi2inv\nfunction and used as Mahalanobis gating threshold.\n\"\"\"\nchi2inv95 = {\n    1: 3.8415,\n    2: 5.9915,\n    3: 7.8147,\n    4: 9.4877,\n    5: 11.070,\n    6: 12.592,\n    7: 14.067,\n    8: 15.507,\n    9: 16.919}\n\n\nclass KalmanFilter(object):\n    \"\"\"\n    A simple Kalman filter for tracking bounding boxes in image space.\n\n    The 8-dimensional state space\n\n        x, y, a, h, vx, vy, va, vh\n\n    contains the bounding box center position (x, y), aspect ratio a, height h,\n    and their respective velocities.\n\n    Object motion follows a constant velocity model. The bounding box location\n    (x, y, a, h) is taken as direct observation of the state space (linear\n    observation model).\n\n    \"\"\"\n\n    def __init__(self):\n        ndim, dt = 4, 1.\n\n        # Create Kalman filter model matrices.\n        self._motion_mat = np.eye(2 * ndim, 2 * ndim)\n        for i in range(ndim):\n            self._motion_mat[i, ndim + i] = dt\n        self._update_mat = np.eye(ndim, 2 * ndim)\n\n        # Motion and observation uncertainty are chosen relative to the current\n        # state estimate. These weights control the amount of uncertainty in\n        # the model. This is a bit hacky.\n        self._std_weight_position = 1. / 20\n        self._std_weight_velocity = 1. 
/ 160\n\n    def initiate(self, measurement):\n        \"\"\"Create track from unassociated measurement.\n\n        Parameters\n        ----------\n        measurement : ndarray\n            Bounding box coordinates (x, y, a, h) with center position (x, y),\n            aspect ratio a, and height h.\n\n        Returns\n        -------\n        (ndarray, ndarray)\n            Returns the mean vector (8 dimensional) and covariance matrix (8x8\n            dimensional) of the new track. Unobserved velocities are initialized\n            to 0 mean.\n\n        \"\"\"\n        mean_pos = measurement\n        mean_vel = np.zeros_like(mean_pos)\n        mean = np.r_[mean_pos, mean_vel]\n\n        std = [\n            2 * self._std_weight_position * measurement[3],\n            2 * self._std_weight_position * measurement[3],\n            1e-2,\n            2 * self._std_weight_position * measurement[3],\n            10 * self._std_weight_velocity * measurement[3],\n            10 * self._std_weight_velocity * measurement[3],\n            1e-5,\n            10 * self._std_weight_velocity * measurement[3]]\n        covariance = np.diag(np.square(std))\n        return mean, covariance\n\n    def predict(self, mean, covariance):\n        \"\"\"Run Kalman filter prediction step.\n\n        Parameters\n        ----------\n        mean : ndarray\n            The 8 dimensional mean vector of the object state at the previous\n            time step.\n        covariance : ndarray\n            The 8x8 dimensional covariance matrix of the object state at the\n            previous time step.\n\n        Returns\n        -------\n        (ndarray, ndarray)\n            Returns the mean vector and covariance matrix of the predicted\n            state. 
\n\n        \"\"\"\n        std_pos = [\n            self._std_weight_position * mean[3],\n            self._std_weight_position * mean[3],\n            1e-2,\n            self._std_weight_position * mean[3]]\n        std_vel = [\n            self._std_weight_velocity * mean[3],\n            self._std_weight_velocity * mean[3],\n            1e-5,\n            self._std_weight_velocity * mean[3]]\n        motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))\n\n        mean = np.dot(mean, self._motion_mat.T)\n        covariance = np.linalg.multi_dot((\n            self._motion_mat, covariance, self._motion_mat.T)) + motion_cov\n\n        return mean, covariance\n\n    def project(self, mean, covariance):\n        \"\"\"Project state distribution to measurement space.\n\n        Parameters\n        ----------\n        mean : ndarray\n            The state's mean vector (8 dimensional array).\n        covariance : ndarray\n            The state's covariance matrix (8x8 dimensional).\n\n        Returns\n        -------\n        (ndarray, ndarray)\n            Returns the projected mean and covariance matrix of the given state\n            estimate.\n\n        \"\"\"\n        std = [\n            self._std_weight_position * mean[3],\n            self._std_weight_position * mean[3],\n            1e-1,\n            self._std_weight_position * mean[3]]\n        innovation_cov = np.diag(np.square(std))\n\n        mean = np.dot(self._update_mat, mean)\n        covariance = np.linalg.multi_dot((\n            self._update_mat, covariance, self._update_mat.T))\n        return mean, covariance + innovation_cov\n\n    def multi_predict(self, mean, covariance):\n        \"\"\"Run Kalman filter prediction step (Vectorized version).\n\n        Parameters\n        ----------\n        mean : ndarray\n            The Nx8 dimensional mean matrix of the object states at the previous\n            time step.\n        covariance : 
ndarray\n            The Nx8x8 dimensional covariance matrices of the object states at the\n            previous time step.\n\n        Returns\n        -------\n        (ndarray, ndarray)\n            Returns the Nx8 mean matrix and the Nx8x8 covariance matrices of the\n            predicted states.\n\n        \"\"\"\n        std_pos = [\n            self._std_weight_position * mean[:, 3],\n            self._std_weight_position * mean[:, 3],\n            1e-2 * np.ones_like(mean[:, 3]),\n            self._std_weight_position * mean[:, 3]]\n        std_vel = [\n            self._std_weight_velocity * mean[:, 3],\n            self._std_weight_velocity * mean[:, 3],\n            1e-5 * np.ones_like(mean[:, 3]),\n            self._std_weight_velocity * mean[:, 3]]\n        sqr = np.square(np.r_[std_pos, std_vel]).T\n\n        motion_cov = []\n        for i in range(len(mean)):\n            motion_cov.append(np.diag(sqr[i]))\n        motion_cov = np.asarray(motion_cov)\n\n        mean = np.dot(mean, self._motion_mat.T)\n        left = np.dot(self._motion_mat, covariance).transpose((1,0,2))\n        covariance = np.dot(left, self._motion_mat.T) + motion_cov\n\n        return mean, covariance\n\n    def update(self, mean, covariance, measurement):\n        \"\"\"Run Kalman filter correction step.\n\n        Parameters\n        ----------\n        mean : ndarray\n            The predicted state's mean vector (8 dimensional).\n        covariance : ndarray\n            The state's covariance matrix (8x8 dimensional).\n        measurement : ndarray\n            The 4 dimensional measurement vector (x, y, a, h), where (x, y)\n            is the center position, a the aspect ratio, and h the height of the\n            bounding box.\n\n        Returns\n        -------\n        (ndarray, ndarray)\n            Returns the measurement-corrected state distribution.\n\n        \"\"\"\n        projected_mean, projected_cov = 
self.project(mean, covariance)\n\n        chol_factor, lower = scipy.linalg.cho_factor(\n            projected_cov, lower=True, check_finite=False)\n        kalman_gain = scipy.linalg.cho_solve(\n            (chol_factor, lower), np.dot(covariance, self._update_mat.T).T,\n            check_finite=False).T\n        innovation = measurement - projected_mean\n\n        new_mean = mean + np.dot(innovation, kalman_gain.T)\n        new_covariance = covariance - np.linalg.multi_dot((\n            kalman_gain, projected_cov, kalman_gain.T))\n        return new_mean, new_covariance\n\n    def gating_distance(self, mean, covariance, measurements,\n                        only_position=False, metric='maha'):\n        \"\"\"Compute gating distance between state distribution and measurements.\n\n        A suitable distance threshold can be obtained from `chi2inv95`. If\n        `only_position` is False, the chi-square distribution has 4 degrees of\n        freedom, otherwise 2.\n\n        Parameters\n        ----------\n        mean : ndarray\n            Mean vector over the state distribution (8 dimensional).\n        covariance : ndarray\n            Covariance of the state distribution (8x8 dimensional).\n        measurements : ndarray\n            An Nx4 dimensional matrix of N measurements, each in\n            format (x, y, a, h) where (x, y) is the bounding box center\n            position, a the aspect ratio, and h the height.\n        only_position : Optional[bool]\n            If True, distance computation is done with respect to the bounding\n            box center position only.\n\n        Returns\n        -------\n        ndarray\n            Returns an array of length N, where the i-th element contains the\n            squared Mahalanobis distance between (mean, covariance) and\n            `measurements[i]`.\n\n        \"\"\"\n        mean, covariance = self.project(mean, covariance)\n        if only_position:\n            mean, covariance = mean[:2], 
covariance[:2, :2]\n            measurements = measurements[:, :2]\n        \n        d = measurements - mean\n        if metric == 'gaussian':\n            return np.sum(d * d, axis=1)\n        elif metric == 'maha':\n            cholesky_factor = np.linalg.cholesky(covariance)\n            z = scipy.linalg.solve_triangular(\n                cholesky_factor, d.T, lower=True, check_finite=False,\n                overwrite_b=True)\n            squared_maha = np.sum(z * z, axis=0)\n            return squared_maha\n        else:\n            raise ValueError('invalid distance metric')\n\n"
  },
  {
    "path": "unitrack/core/propagation/__init__.py",
    "content": "###################################################################\n# File Name: __init__.py\n# Author: Zhongdao Wang\n# mail: wcd17@mails.tsinghua.edu.cn\n# Created Time: Mon Jan 18 15:57:34 2021\n###################################################################\n\nfrom __future__ import print_function\nfrom __future__ import division\nfrom __future__ import absolute_import\n\nfrom .propagate_box import propagate_box\nfrom .propagate_mask import propagate_mask\nfrom .propagate_pose import propagate_pose\n\ndef propagate(temp_feats, obs, img, model, format='box'):\n    # dispatch to the propagator that matches the observation format\n    if format == 'box':\n        return propagate_box(temp_feats, obs, img, model)\n    elif format == 'mask':\n        return propagate_mask(temp_feats, obs, img, model)\n    elif format == 'pose':\n        return propagate_pose(temp_feats, obs, img, model)\n    else:\n        raise ValueError('Observation format not supported.')\n"
  },
  {
    "path": "unitrack/core/propagation/propagate_box.py",
    "content": "###################################################################\n# File Name: propagate_box.py\n# Author: Zhongdao Wang\n# mail: wcd17@mails.tsinghua.edu.cn\n# Created Time: Mon Jan 18 16:01:46 2021\n###################################################################\n\nfrom __future__ import print_function\nfrom __future__ import division\nfrom __future__ import absolute_import\n\ndef propagate_box(temp_feats, box, img, model):\n    pass\n"
  },
  {
    "path": "unitrack/core/propagation/propagate_mask.py",
    "content": "###################################################################\n# File Name: propagate_mask.py\n# Author: Zhongdao Wang\n# mail: wcd17@mails.tsinghua.edu.cn\n# Created Time: Mon Jan 18 16:01:46 2021\n###################################################################\n\nfrom __future__ import print_function\nfrom __future__ import division\nfrom __future__ import absolute_import\n\ndef propagate_mask(temp_feats, mask, img, model):\n    pass\n"
  },
  {
    "path": "unitrack/core/propagation/propagate_pose.py",
    "content": "###################################################################\n# File Name: propagate_pose.py\n# Author: Zhongdao Wang\n# mail: wcd17@mails.tsinghua.edu.cn\n# Created Time: Mon Jan 18 16:01:46 2021\n###################################################################\n\nfrom __future__ import print_function\nfrom __future__ import division\nfrom __future__ import absolute_import\n\ndef propagate_pose(temp_feats, pose, img, model):\n    pass\n"
  },
  {
    "path": "unitrack/mask.py",
    "content": "###################################################################\n# File Name: mask.py\n# Author: Zhongdao Wang\n# mail: wcd17@mails.tsinghua.edu.cn\n# Created Time: Fri Jan 29 15:16:53 2021\n###################################################################\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\n\nfrom unitrack.utils.box import *\nfrom unitrack.utils.mask import *\nfrom .basetrack import *\nfrom .multitracker import AssociationTracker\n\n\nclass MaskAssociationTracker(AssociationTracker):\n    def __init__(self, opt):\n        super(MaskAssociationTracker, self).__init__(opt)\n\n    def extract_emb(self, img, obs):\n        img = img.to(self.opt.device).float()\n        with torch.no_grad():\n            feat = self.app_model(img)\n        _, d, h, w = feat.shape\n        obs = torch.from_numpy(obs).to(self.opt.device).float()\n        obs = F.interpolate(obs.unsqueeze(1), size=(h,w), mode='nearest')\n        template_scale = np.prod(self.opt.feat_size)\n        embs = []\n        for ob in obs:\n            obfeat = ob*feat\n            scale = ob.sum()\n            if scale > 0:\n                if scale > self.opt.max_mask_area:\n                    scale_factor = np.sqrt(self.opt.max_mask_area/scale.item())\n                else:\n                    scale_factor = 1\n                norm_obfeat = F.interpolate(obfeat, scale_factor=scale_factor, mode='bilinear')\n                norm_mask = F.interpolate(ob.unsqueeze(1), scale_factor=scale_factor, mode='nearest')\n                emb = norm_obfeat[:,:, norm_mask.squeeze(0).squeeze(0).ge(0.5)]\n                # print(\"embedding\", emb.shape)\n                embs.append(emb.cpu())\n            else: \n                embs.append(torch.randn(d, template_scale))\n        return obs, embs\n\n    def prepare_obs(self, img, img0, obs, embs=None):\n        ''' Step 1: Network forward, get detections & embeddings'''\n        if obs.shape[0] > 0:\n            masks, 
embs = self.extract_emb(img, obs)\n            boxes = mask2box(masks)\n            keep_idx = remove_duplicated_box(boxes, iou_th=0.7)\n            boxes, masks, obs = boxes[keep_idx], masks[keep_idx], obs[keep_idx]\n            embs = [embs[k] for k in keep_idx]\n            detections = [STrack(tlbr_to_tlwh(tlbrs), 1, f, self.buffer_size, mask, ac=True) \\\n                    for (tlbrs,mask,f) in zip(boxes, obs, embs)]\n        else:\n            detections = []\n        return detections\n\n"
  },
  {
    "path": "unitrack/mask_with_train_embs.py",
    "content": "###################################################################\n# File Name: mask_with_train_embs.py\n# Author: Zhongdao Wang\n# mail: wcd17@mails.tsinghua.edu.cn\n# Created Time: Fri Jan 29 15:16:53 2021\n###################################################################\nimport time\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\n\nfrom unitrack.utils.box import *\nfrom unitrack.utils.mask import *\nfrom .basetrack import *\n\nfrom unitrack.model import AppearanceModel\n\nclass AssociationTrackerWithTrainedEmbed(object):\n    def __init__(self, opt):\n        self.opt = opt\n        self.tracked_stracks = []  # type: list[STrack]\n        self.lost_stracks = []  # type: list[STrack]\n        self.removed_stracks = []  # type: list[STrack]\n\n        self.frame_id = 0\n        self.det_thresh = opt.conf_thres\n        self.buffer_size = opt.track_buffer\n        self.max_time_lost = self.buffer_size\n\n        self.kalman_filter = KalmanFilter()\n\n        # self.app_model = AppearanceModel(opt).to(opt.device)\n        # self.app_model.eval()\n\n        if not self.opt.asso_with_motion:\n            self.opt.motion_lambda = 1\n            self.opt.motion_gated = False\n\n    def extract_emb(self, img, obs):\n        raise NotImplementedError\n\n    def prepare_obs(self, img, img0, obs, embs=None):\n        raise NotImplementedError\n\n    def update(self, img, img0, obs, embs=None):\n        torch.cuda.empty_cache()\n        self.frame_id += 1\n        activated_stracks = []\n        refind_stracks = []\n        lost_stracks = []\n        removed_stracks = []\n\n        t1 = time.time()\n        # forward the caller's trained embeddings to the observation step\n        detections = self.prepare_obs(img, img0, obs, embs=embs)\n\n        ''' Add newly detected tracklets to tracked_stracks'''\n        unconfirmed = []\n        tracked_stracks = []  # type: list[STrack]\n        for track in self.tracked_stracks:\n            if not track.is_activated:\n                unconfirmed.append(track)\n            
else:\n                tracked_stracks.append(track)\n\n        ''' Step 2: First association, with embedding'''\n        tracks = joint_stracks(tracked_stracks, self.lost_stracks)\n        dists, recons_ftrk = matching.center_emb_distance(tracks, detections)\n        if self.opt.use_kalman:\n            # Predict the current location with KF\n            STrack.multi_predict(tracks)\n            dists = matching.fuse_motion(self.kalman_filter, dists, tracks, detections,\n                                         lambda_=self.opt.motion_lambda, gate=self.opt.motion_gated)\n        if obs.shape[1] == 6:\n            dists = matching.category_gate(dists, tracks, detections)\n        matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.7)\n\n        for itracked, idet in matches:\n            track = tracks[itracked]\n            det = detections[idet]\n            if track.state == TrackState.Tracked:\n                track.update(detections[idet], self.frame_id)\n                activated_stracks.append(track)\n            else:\n                track.re_activate(det, self.frame_id, new_id=False)\n                refind_stracks.append(track)\n\n        if self.opt.use_kalman:\n            '''(optional) Step 3: Second association, with IOU'''\n            tracks = [tracks[i] for i in u_track if tracks[i].state == TrackState.Tracked]\n            detections = [detections[i] for i in u_detection]\n            dists = matching.iou_distance(tracks, detections)\n            matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.5)\n\n            for itracked, idet in matches:\n                track = tracks[itracked]\n                det = detections[idet]\n                if track.state == TrackState.Tracked:\n                    track.update(det, self.frame_id)\n                    activated_stracks.append(track)\n                else:\n                    track.re_activate(det, self.frame_id, new_id=False)\n                    
refind_stracks.append(track)\n\n            '''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''\n            detections = [detections[i] for i in u_detection]\n            dists = matching.iou_distance(unconfirmed, detections)\n            matches, u_unconfirmed, u_detection = matching.linear_assignment(\n                dists, thresh=self.opt.confirm_iou_thres)\n            for itracked, idet in matches:\n                unconfirmed[itracked].update(detections[idet], self.frame_id)\n                activated_stracks.append(unconfirmed[itracked])\n            for it in u_unconfirmed:\n                track = unconfirmed[it]\n                track.mark_removed()\n                removed_stracks.append(track)\n\n        for it in u_track:\n            track = tracks[it]\n            if not track.state == TrackState.Lost:\n                track.mark_lost()\n                lost_stracks.append(track)\n\n        \"\"\" Step 4: Init new stracks\"\"\"\n        for inew in u_detection:\n            track = detections[inew]\n            if track.score < self.det_thresh:\n                continue\n            track.activate(self.kalman_filter, self.frame_id)\n            activated_stracks.append(track)\n\n        \"\"\" Step 5: Update state\"\"\"\n        for track in self.lost_stracks:\n            if self.frame_id - track.end_frame > self.max_time_lost:\n                track.mark_removed()\n                removed_stracks.append(track)\n\n        self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked]\n        self.tracked_stracks = joint_stracks(self.tracked_stracks, activated_stracks)\n        self.tracked_stracks = joint_stracks(self.tracked_stracks, refind_stracks)\n        self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks)\n        self.lost_stracks.extend(lost_stracks)\n        self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks)\n        
self.removed_stracks.extend(removed_stracks)\n        self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(\n            self.tracked_stracks, self.lost_stracks, ioudist=self.opt.dup_iou_thres)\n\n        # output only the tracked tracks that have been activated\n        output_stracks = [track for track in self.tracked_stracks if track.is_activated]\n\n        return output_stracks\n\n    def reset_all(self, ):\n        self.tracked_stracks = []  # type: list[STrack]\n        self.lost_stracks = []  # type: list[STrack]\n        self.removed_stracks = []  # type: list[STrack]\n        self.frame_id = 0\n\n\nclass MaskAssociationTracker(AssociationTrackerWithTrainedEmbed):\n    def __init__(self, opt):\n        super(MaskAssociationTracker, self).__init__(opt)\n\n    def extract_emb(self, img, obs, embs):\n        # Uses the trained embeddings passed in by the caller (one per mask)\n        # instead of pooling features from an appearance model;\n        # unsqueeze(-1) adds a trailing singleton dimension to each embedding.\n        img = img.to(self.opt.device).float()\n        obs = obs.to(self.opt.device).float()\n        embs = embs.to(self.opt.device).float().unsqueeze(-1)\n        # obs = F.interpolate(obs.unsqueeze(1), size=(h,w), mode='nearest')\n        # template_scale = np.prod(self.opt.feat_size)\n        embs_list = []\n        for emb in embs:\n            # obfeat = ob\n            embs_list.append(emb.cpu())\n            # scale = ob.sum()\n            # if scale > 0:\n            #     if scale > self.opt.max_mask_area:\n            #         scale_factor = np.sqrt(self.opt.max_mask_area/scale.item())\n            #     else:\n            #         scale_factor = 1\n            #     norm_obfeat = F.interpolate(obfeat, scale_factor=scale_factor, mode='bilinear')\n            #     norm_mask = F.interpolate(ob.unsqueeze(1), scale_factor=scale_factor, mode='nearest')\n            #     emb = norm_obfeat[:,:, norm_mask.squeeze(0).squeeze(0).ge(0.5)]\n            #     embs.append(emb.cpu())\n            # else:\n            #     embs.append(torch.randn(d, 
template_scale))\n        return obs, embs_list\n\n    def prepare_obs(self, img, img0, obs, embs=None):\n        ''' Step 1: Network forward, get detections & embeddings'''\n        if obs.shape[0] > 0:\n            if embs is None:\n                raise ValueError('MaskAssociationTracker requires trained embeddings (embs) for every frame with observations.')\n            masks, embs = self.extract_emb(img, obs, embs)\n            boxes = mask2box(masks)\n            keep_idx = remove_duplicated_box(boxes, iou_th=0.7)\n            boxes, masks, obs = boxes[keep_idx], masks[keep_idx], obs[keep_idx]\n            embs = [embs[k] for k in keep_idx]\n            detections = [STrack(tlbr_to_tlwh(tlbrs), 1, f, self.buffer_size, mask, ac=True) \\\n                    for (tlbrs,mask,f) in zip(boxes, obs, embs)]\n        else:\n            detections = []\n        return detections\n\n\n"
  },
  {
    "path": "unitrack/model/__init__.py",
    "content": "###################################################################\n# File Name: __init__.py\n# Author: Zhongdao Wang\n# mail: wcd17@mails.tsinghua.edu.cn\n# Created Time: Thu Dec 24 14:24:44 2020\n###################################################################\n\nfrom __future__ import print_function\nfrom __future__ import division\nfrom __future__ import absolute_import\n\nfrom .model import *\nfrom .resnet import *\n"
  },
  {
    "path": "unitrack/model/functional.py",
    "content": "###################################################################\n# File Name: functional.py\n# Author: Zhongdao Wang\n# mail: wcd17@mails.tsinghua.edu.cn\n# Created Time: Mon Jun 21 21:04:09 2021\n###################################################################\n\nfrom __future__ import print_function\nfrom __future__ import division\nfrom __future__ import absolute_import\n\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\ndef hard_prop(pred):\n    pred_max = pred.max(axis=0)[0]\n    pred[pred <  pred_max] = 0\n    pred[pred >= pred_max] = 1\n    pred /= pred.sum(0)[None]\n    return pred\n\ndef context_index_bank(n_context, long_mem, N):\n    '''\n    Construct bank of source frames indices, for each target frame\n    '''\n    ll = []   # \"long term\" context (i.e. first frame)\n    for t in long_mem:\n        assert 0 <= t < N, 'context frame out of bounds'\n        idx = torch.zeros(N, 1).long()\n        if t > 0:\n            idx += t + (n_context+1)\n            idx[:n_context+t+1] = 0\n        ll.append(idx)\n    # \"short\" context    \n    ss = [(torch.arange(n_context)[None].repeat(N, 1) +  \\\n            torch.arange(N)[:, None])[:, :]]\n    return ll + ss\n\n\ndef mem_efficient_batched_affinity(\n        query, keys, mask, temperature, topk, long_mem, device):\n    '''\n    Mini-batched computation of affinity, for memory efficiency\n    '''\n    bsize, pbsize = 10, 100 #keys.shape[2] // 2\n    Ws, Is = [], []\n\n    for b in range(0, keys.shape[2], bsize):\n        _k, _q = keys[:, :, b:b+bsize].to(device), query[:, :, b:b+bsize].to(device)\n        w_s, i_s = [], []\n\n        for pb in range(0, _k.shape[-1], pbsize):\n            A = torch.einsum('ijklm,ijkn->iklmn', _k, _q[..., pb:pb+pbsize]) \n            A[0, :, len(long_mem):] += mask[..., pb:pb+pbsize].to(device)\n\n            _, N, T, h1w1, hw = A.shape\n            A = A.view(N, T*h1w1, hw)\n            A /= temperature\n\n            weights, 
ids = torch.topk(A, topk, dim=-2)\n            weights = F.softmax(weights, dim=-2)\n            \n            w_s.append(weights.cpu())\n            i_s.append(ids.cpu())\n\n        weights = torch.cat(w_s, dim=-1)\n        ids = torch.cat(i_s, dim=-1)\n        Ws += [w for w in weights]\n        Is += [ii for ii in ids]\n\n    return Ws, Is\n\ndef batched_affinity(query, keys, mask, temperature, topk, long_mem, device):\n    '''\n    Mini-batched computation of affinity, for memory efficiency\n    (less aggressively mini-batched)\n    '''\n    bsize = 2\n    Ws, Is = [], []\n    for b in range(0, keys.shape[2], bsize):\n        _k, _q = keys[:, :, b:b+bsize].to(device), query[:, :, b:b+bsize].to(device)\n\n        # Same 5-D layout as mem_efficient_batched_affinity: (1, N, T, h1w1, hw)\n        A = torch.einsum('ijklm,ijkn->iklmn', _k, _q)\n\n        # Mask\n        A[0, :, len(long_mem):] += mask.to(device)\n\n        _, N, T, h1w1, hw = A.shape\n        A = A.view(N, T*h1w1, hw)\n        # scale by the temperature exactly once\n        A /= temperature\n\n        weights, ids = torch.topk(A, topk, dim=-2)\n        weights = F.softmax(weights, dim=-2)\n\n        Ws += [w for w in weights]\n        Is += [ii for ii in ids]\n\n    return Ws, Is\n\ndef process_pose(pred, lbl_set, topk=3):\n    # generate the coordinates:\n    pred = pred[..., 1:]\n    flatlbls = pred.flatten(0,1)\n    topk = min(flatlbls.shape[0], topk)\n    \n    vals, ids = torch.topk(flatlbls, k=topk, dim=0)\n    vals /= vals.sum(0)[None]\n    xx, yy = ids % pred.shape[1], ids // pred.shape[1]\n\n    current_coord = torch.stack([(xx * vals).sum(0), (yy * vals).sum(0)], dim=0)\n    current_coord[:, flatlbls.sum(0) == 0] = -1\n\n    pred_val_sharp = np.zeros((*pred.shape[:2], 3))\n\n    for t in range(len(lbl_set) - 1):\n        x = int(current_coord[0, t])\n        y = int(current_coord[1, t])\n\n        if x >=0 and y >= 0:\n            pred_val_sharp[y, x, :] = lbl_set[t + 1]\n\n    return current_coord.cpu(), pred_val_sharp\n\nclass 
MaskedAttention(nn.Module):\n    '''\n    A module that implements masked attention based on spatial locality \n    TODO implement in a more efficient way (torch sparse or correlation filter)\n    '''\n    def __init__(self, radius, flat=True):\n        super(MaskedAttention, self).__init__()\n        self.radius = radius\n        self.flat = flat\n        self.masks = {}\n        # index cache; named so it does not shadow the index() method\n        self.indices = {}\n\n    def mask(self, H, W):\n        if not ('%s-%s' %(H,W) in self.masks):\n            self.make(H, W)\n        return self.masks['%s-%s' %(H,W)]\n\n    def index(self, H, W):\n        if not ('%s-%s' %(H,W) in self.indices):\n            self.make_index(H, W)\n        return self.indices['%s-%s' %(H,W)]\n\n    def make(self, H, W):\n        # cache under the caller's (H, W) so mask() and forward() can find it\n        key = '%s-%s' % (H, W)\n        if self.flat:\n            H = int(H**0.5)\n            W = int(W**0.5)\n\n        gx, gy = torch.meshgrid(torch.arange(0, H), torch.arange(0, W))\n        D = ( (gx[None, None, :, :] - gx[:, :, None, None])**2 + (gy[None, None, :, :] - gy[:, :, None, None])**2 ).float() ** 0.5\n        D = (D < self.radius)[None].float()\n\n        if self.flat:\n            D = self.flatten(D)\n        self.masks[key] = D\n\n        return D\n\n    def flatten(self, D):\n        return torch.flatten(torch.flatten(D, 1, 2), -2, -1)\n\n    def make_index(self, H, W, pad=False):\n        # boolean (not uint8) mask for indexing\n        mask = self.mask(H, W).view(1, -1).bool()\n        idx = torch.arange(0, mask.numel())[mask[0]][None]\n\n        self.indices['%s-%s' %(H,W)] = idx\n\n        return idx\n\n    def forward(self, x):\n        H, W = x.shape[-2:]\n        sid = '%s-%s' % (H,W)\n        if sid not in self.masks:\n            self.masks[sid] = self.make(H, W).to(x.device)\n        mask = self.masks[sid]\n        return x * mask[0]\n"
  },
  {
    "path": "unitrack/model/hrnet.py",
    "content": "# ------------------------------------------------------------------------------\n# Copyright (c) Microsoft\n# Licensed under the MIT License.\n# Written by Bin Xiao (Bin.Xiao@microsoft.com)\n# Modified by Ke Sun (sunk@mail.ustc.edu.cn)\n# Modified by Zhongdao Wang(wcd17@mails.tsinghua.edu.cn)\n# ------------------------------------------------------------------------------\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport os\nimport pdb\nimport logging\nimport functools\n\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\nimport torch._utils\nimport torch.nn.functional as F\n\nBN_MOMENTUM = 0.1\nlogger = logging.getLogger(__name__)\n\n\ndef conv3x3(in_planes, out_planes, stride=1):\n    \"\"\"3x3 convolution with padding\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n                     padding=1, bias=False)\n\n\nclass BasicBlock(nn.Module):\n    expansion = 1\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None):\n        super(BasicBlock, self).__init__()\n        self.conv1 = conv3x3(inplanes, planes, stride)\n        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)\n        self.relu = nn.ReLU(inplace=True)\n        self.conv2 = conv3x3(planes, planes)\n        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        residual = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n\n        if self.downsample is not None:\n            residual = self.downsample(x)\n\n        out += residual\n        out = self.relu(out)\n\n        return out\n\n\nclass Bottleneck(nn.Module):\n    expansion = 4\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None):\n        super(Bottleneck, 
self).__init__()\n        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)\n        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)\n        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,\n                               padding=1, bias=False)\n        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)\n        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,\n                               bias=False)\n        self.bn3 = nn.BatchNorm2d(planes * self.expansion,\n                               momentum=BN_MOMENTUM)\n        self.relu = nn.ReLU(inplace=True)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        residual = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        out = self.relu(out)\n\n        out = self.conv3(out)\n        out = self.bn3(out)\n\n        if self.downsample is not None:\n            residual = self.downsample(x)\n\n        out += residual\n        out = self.relu(out)\n\n        return out\n\n\nclass HighResolutionModule(nn.Module):\n    def __init__(self, num_branches, blocks, num_blocks, num_inchannels,\n                 num_channels, fuse_method, multi_scale_output=True):\n        super(HighResolutionModule, self).__init__()\n        self._check_branches(\n            num_branches, blocks, num_blocks, num_inchannels, num_channels)\n\n        self.num_inchannels = num_inchannels\n        self.fuse_method = fuse_method\n        self.num_branches = num_branches\n\n        self.multi_scale_output = multi_scale_output\n\n        self.branches = self._make_branches(\n            num_branches, blocks, num_blocks, num_channels)\n        self.fuse_layers = self._make_fuse_layers()\n        self.relu = nn.ReLU(False)\n\n    def _check_branches(self, num_branches, blocks, num_blocks,\n                        
num_inchannels, num_channels):\n        if num_branches != len(num_blocks):\n            error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(\n                num_branches, len(num_blocks))\n            logger.error(error_msg)\n            raise ValueError(error_msg)\n\n        if num_branches != len(num_channels):\n            error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(\n                num_branches, len(num_channels))\n            logger.error(error_msg)\n            raise ValueError(error_msg)\n\n        if num_branches != len(num_inchannels):\n            error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(\n                num_branches, len(num_inchannels))\n            logger.error(error_msg)\n            raise ValueError(error_msg)\n\n    def _make_one_branch(self, branch_index, block, num_blocks, num_channels,\n                         stride=1):\n        downsample = None\n        if stride != 1 or \\\n           self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:\n            downsample = nn.Sequential(\n                nn.Conv2d(self.num_inchannels[branch_index],\n                          num_channels[branch_index] * block.expansion,\n                          kernel_size=1, stride=stride, bias=False),\n                nn.BatchNorm2d(num_channels[branch_index] * block.expansion,\n                            momentum=BN_MOMENTUM),\n            )\n\n        layers = []\n        layers.append(block(self.num_inchannels[branch_index],\n                            num_channels[branch_index], stride, downsample))\n        self.num_inchannels[branch_index] = \\\n            num_channels[branch_index] * block.expansion\n        for i in range(1, num_blocks[branch_index]):\n            layers.append(block(self.num_inchannels[branch_index],\n                                num_channels[branch_index]))\n\n        return nn.Sequential(*layers)\n\n    def _make_branches(self, num_branches, block, num_blocks, 
num_channels):\n        branches = []\n\n        for i in range(num_branches):\n            branches.append(\n                self._make_one_branch(i, block, num_blocks, num_channels))\n\n        return nn.ModuleList(branches)\n\n    def _make_fuse_layers(self):\n        if self.num_branches == 1:\n            return None\n\n        num_branches = self.num_branches\n        num_inchannels = self.num_inchannels\n        fuse_layers = []\n        for i in range(num_branches if self.multi_scale_output else 1):\n            fuse_layer = []\n            for j in range(num_branches):\n                if j > i:\n                    fuse_layer.append(nn.Sequential(\n                        nn.Conv2d(num_inchannels[j],\n                                  num_inchannels[i],\n                                  1,\n                                  1,\n                                  0,\n                                  bias=False),\n                        nn.BatchNorm2d(num_inchannels[i], \n                                       momentum=BN_MOMENTUM),\n                        nn.Upsample(scale_factor=2**(j-i), mode='nearest')))\n                elif j == i:\n                    fuse_layer.append(None)\n                else:\n                    conv3x3s = []\n                    for k in range(i-j):\n                        if k == i - j - 1:\n                            num_outchannels_conv3x3 = num_inchannels[i]\n                            conv3x3s.append(nn.Sequential(\n                                nn.Conv2d(num_inchannels[j],\n                                          num_outchannels_conv3x3,\n                                          3, 2, 1, bias=False),\n                                nn.BatchNorm2d(num_outchannels_conv3x3, \n                                            momentum=BN_MOMENTUM)))\n                        else:\n                            num_outchannels_conv3x3 = num_inchannels[j]\n                            conv3x3s.append(nn.Sequential(\n        
                        nn.Conv2d(num_inchannels[j],\n                                          num_outchannels_conv3x3,\n                                          3, 2, 1, bias=False),\n                                nn.BatchNorm2d(num_outchannels_conv3x3,\n                                            momentum=BN_MOMENTUM),\n                                nn.ReLU(False)))\n                    fuse_layer.append(nn.Sequential(*conv3x3s))\n            fuse_layers.append(nn.ModuleList(fuse_layer))\n\n        return nn.ModuleList(fuse_layers)\n\n    def get_num_inchannels(self):\n        return self.num_inchannels\n\n    def forward(self, x):\n        if self.num_branches == 1:\n            return [self.branches[0](x[0])]\n\n        for i in range(self.num_branches):\n            x[i] = self.branches[i](x[i])\n\n        x_fuse = []\n        for i in range(len(self.fuse_layers)):\n            y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])\n            for j in range(1, self.num_branches):\n                if i == j:\n                    y = y + x[j]\n                else:\n                    fused = self.fuse_layers[i][j](x[j])\n                    fh, fw = fused.shape[-2:]\n                    yh, yw = y.shape[-2:]\n                    if fh > yh:\n                        fused = fused[:,:,(fh-yh)//2:-(fh-yh)//2,:]\n                    if fw > yw:\n                        fused = fused[:,:,:,(fw-yw)//2:-(fw-yw)//2]\n                    y = y + fused\n            x_fuse.append(self.relu(y))\n\n        return x_fuse\n\n\nblocks_dict = {\n    'BASIC': BasicBlock,\n    'BOTTLENECK': Bottleneck\n}\n\n\nclass HighResolutionNet(nn.Module):\n\n    def __init__(self, cfg, **kwargs):\n        super(HighResolutionNet, self).__init__()\n        self.cfg = cfg\n\n        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1,\n                               bias=False)\n        self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)\n        self.conv2 = nn.Conv2d(64, 
64, kernel_size=3, stride=2, padding=1,\n                               bias=False)\n        self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)\n        self.relu = nn.ReLU(inplace=True)\n\n        self.stage1_cfg = cfg['MODEL']['EXTRA']['STAGE1']\n        num_channels = self.stage1_cfg['NUM_CHANNELS'][0]\n        block = blocks_dict[self.stage1_cfg['BLOCK']]\n        num_blocks = self.stage1_cfg['NUM_BLOCKS'][0]\n        self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)\n        stage1_out_channel = block.expansion*num_channels\n\n        self.stage2_cfg = cfg['MODEL']['EXTRA']['STAGE2']\n        num_channels = self.stage2_cfg['NUM_CHANNELS']\n        block = blocks_dict[self.stage2_cfg['BLOCK']]\n        num_channels = [\n            num_channels[i] * block.expansion for i in range(len(num_channels))]\n        self.transition1 = self._make_transition_layer(\n            [stage1_out_channel], num_channels)\n        self.stage2, pre_stage_channels = self._make_stage(\n            self.stage2_cfg, num_channels)\n\n        self.stage3_cfg = cfg['MODEL']['EXTRA']['STAGE3']\n        num_channels = self.stage3_cfg['NUM_CHANNELS']\n        block = blocks_dict[self.stage3_cfg['BLOCK']]\n        num_channels = [\n            num_channels[i] * block.expansion for i in range(len(num_channels))]\n        self.transition2 = self._make_transition_layer(\n            pre_stage_channels, num_channels)\n        self.stage3, pre_stage_channels = self._make_stage(\n            self.stage3_cfg, num_channels)\n\n        self.stage4_cfg = cfg['MODEL']['EXTRA']['STAGE4']\n        num_channels = self.stage4_cfg['NUM_CHANNELS']\n        block = blocks_dict[self.stage4_cfg['BLOCK']]\n        num_channels = [\n            num_channels[i] * block.expansion for i in range(len(num_channels))]\n        self.transition3 = self._make_transition_layer(\n            pre_stage_channels, num_channels)\n        self.stage4, pre_stage_channels = self._make_stage(\n            
self.stage4_cfg, num_channels, multi_scale_output=True)\n\n        # Classification Head\n        self.incre_modules, self.downsamp_modules, \\\n            self.final_layer = self._make_head(pre_stage_channels)\n\n        self.classifier = nn.Linear(2048, 1000)\n\n    def _make_head(self, pre_stage_channels):\n        head_block = Bottleneck\n        head_channels = [32, 64, 128, 256]\n\n        # Increasing the #channels on each resolution \n        # from C, 2C, 4C, 8C to 128, 256, 512, 1024\n        incre_modules = []\n        for i, channels  in enumerate(pre_stage_channels):\n            incre_module = self._make_layer(head_block,\n                                            channels,\n                                            head_channels[i],\n                                            1,\n                                            stride=1)\n            incre_modules.append(incre_module)\n        incre_modules = nn.ModuleList(incre_modules)\n            \n        # downsampling modules\n        downsamp_modules = []\n        for i in range(len(pre_stage_channels)-1):\n            in_channels = head_channels[i] * head_block.expansion\n            out_channels = head_channels[i+1] * head_block.expansion\n\n            downsamp_module = nn.Sequential(\n                nn.Conv2d(in_channels=in_channels,\n                          out_channels=out_channels,\n                          kernel_size=3,\n                          stride=2,\n                          padding=1),\n                nn.BatchNorm2d(out_channels, momentum=BN_MOMENTUM),\n                nn.ReLU(inplace=True)\n            )\n\n            downsamp_modules.append(downsamp_module)\n        downsamp_modules = nn.ModuleList(downsamp_modules)\n\n        final_layer = nn.Sequential(\n            nn.Conv2d(\n                in_channels=head_channels[3] * head_block.expansion,\n                out_channels=2048,\n                kernel_size=1,\n                stride=1,\n                
padding=0\n            ),\n            nn.BatchNorm2d(2048, momentum=BN_MOMENTUM),\n            nn.ReLU(inplace=True)\n        )\n\n        return incre_modules, downsamp_modules, final_layer\n\n    def _make_transition_layer(\n            self, num_channels_pre_layer, num_channels_cur_layer):\n        num_branches_cur = len(num_channels_cur_layer)\n        num_branches_pre = len(num_channels_pre_layer)\n\n        transition_layers = []\n        for i in range(num_branches_cur):\n            if i < num_branches_pre:\n                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:\n                    transition_layers.append(nn.Sequential(\n                        nn.Conv2d(num_channels_pre_layer[i],\n                                  num_channels_cur_layer[i],\n                                  3,\n                                  1,\n                                  1,\n                                  bias=False),\n                        nn.BatchNorm2d(\n                            num_channels_cur_layer[i], momentum=BN_MOMENTUM),\n                        nn.ReLU(inplace=True)))\n                else:\n                    transition_layers.append(None)\n            else:\n                conv3x3s = []\n                for j in range(i+1-num_branches_pre):\n                    inchannels = num_channels_pre_layer[-1]\n                    outchannels = num_channels_cur_layer[i] \\\n                        if j == i-num_branches_pre else inchannels\n                    conv3x3s.append(nn.Sequential(\n                        nn.Conv2d(\n                            inchannels, outchannels, 3, 2, 1, bias=False),\n                        nn.BatchNorm2d(outchannels, momentum=BN_MOMENTUM),\n                        nn.ReLU(inplace=True)))\n                transition_layers.append(nn.Sequential(*conv3x3s))\n\n        return nn.ModuleList(transition_layers)\n\n    def _make_layer(self, block, inplanes, planes, blocks, stride=1):\n        downsample = None\n    
    if stride != 1 or inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                nn.Conv2d(inplanes, planes * block.expansion,\n                          kernel_size=1, stride=stride, bias=False),\n                nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),\n            )\n\n        layers = []\n        layers.append(block(inplanes, planes, stride, downsample))\n        inplanes = planes * block.expansion\n        for i in range(1, blocks):\n            layers.append(block(inplanes, planes))\n\n        return nn.Sequential(*layers)\n\n    def _make_stage(self, layer_config, num_inchannels,\n                    multi_scale_output=True):\n        num_modules = layer_config['NUM_MODULES']\n        num_branches = layer_config['NUM_BRANCHES']\n        num_blocks = layer_config['NUM_BLOCKS']\n        num_channels = layer_config['NUM_CHANNELS']\n        block = blocks_dict[layer_config['BLOCK']]\n        fuse_method = layer_config['FUSE_METHOD']\n\n        modules = []\n        for i in range(num_modules):\n            # multi_scale_output is only used last module\n            if not multi_scale_output and i == num_modules - 1:\n                reset_multi_scale_output = False\n            else:\n                reset_multi_scale_output = True\n\n            modules.append(\n                HighResolutionModule(num_branches,\n                                      block,\n                                      num_blocks,\n                                      num_inchannels,\n                                      num_channels,\n                                      fuse_method,\n                                      reset_multi_scale_output)\n            )\n            num_inchannels = modules[-1].get_num_inchannels()\n\n        return nn.Sequential(*modules), num_inchannels\n\n    def forward(self, x):\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = self.relu(x)\n        x = self.conv2(x)\n        x 
= self.bn2(x)\n        x = self.relu(x)\n        x = self.layer1(x)\n\n        x_list = []\n        for i in range(self.stage2_cfg['NUM_BRANCHES']):\n            if self.transition1[i] is not None:\n                x_list.append(self.transition1[i](x))\n            else:\n                x_list.append(x)\n        y_list = self.stage2(x_list)\n\n        x_list = []\n        for i in range(self.stage3_cfg['NUM_BRANCHES']):\n            if self.transition2[i] is not None:\n                x_list.append(self.transition2[i](y_list[-1]))\n            else:\n                x_list.append(y_list[i])\n        y_list = self.stage3(x_list)\n\n        x_list = []\n        for i in range(self.stage4_cfg['NUM_BRANCHES']):\n            if self.transition3[i] is not None:\n                x_list.append(self.transition3[i](y_list[-1]))\n            else:\n                x_list.append(y_list[i])\n        y_list = self.stage4(x_list)\n\n        # Classification Head\n        y_list_out = {}\n        y_list_out[0] = self.incre_modules[0](y_list[0])\n        for i in range(len(self.downsamp_modules)):\n            y_list_out[i+1] = self.incre_modules[i+1](y_list[i+1]) + \\\n                        self.downsamp_modules[i](y_list_out[i])\n\n        #y = self.final_layer(y)\n\n        ret = y_list_out[self.cfg['MODEL']['RETURN_STAGE']]\n\n        ret_size = y_list_out[1].shape[-2:]\n        ret = F.interpolate(ret, ret_size, mode='bilinear')\n        return ret\n\n    def init_weights(self, pretrained='',):\n        print('=> init weights from normal distribution')\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(\n                    m.weight, mode='fan_out', nonlinearity='relu')\n            elif isinstance(m, nn.BatchNorm2d):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n        if os.path.isfile(pretrained):\n            pretrained_dict = torch.load(pretrained)\n      
      print('=> loading pretrained model {}'.format(pretrained))\n            model_dict = self.state_dict()\n            pretrained_dict = {k: v for k, v in pretrained_dict.items()\n                               if k in model_dict.keys()}\n            for k, _ in pretrained_dict.items():\n                print(\n                    '=> loading {} pretrained model {}'.format(k, pretrained))\n            model_dict.update(pretrained_dict)\n            self.load_state_dict(model_dict)\n\n\nconfig = {\n'hrnet_w18': {\n    'MODEL':{\n        'EXTRA':{\n            'STAGE1':{\n                'NUM_MODULES':1,\n                'NUM_BRANCHES':1,\n                'BLOCK': 'BOTTLENECK',\n                'NUM_BLOCKS':[4,],\n                'NUM_CHANNELS':[64,],\n                'FUSE_METHOD': 'SUM',\n                },\n            'STAGE2':{\n                'NUM_MODULES':1,\n                'NUM_BRANCHES':2,\n                'BLOCK': 'BASIC',\n                'NUM_BLOCKS':[4,4,],\n                'NUM_CHANNELS':[18, 36],\n                'FUSE_METHOD': 'SUM',\n                },\n            'STAGE3':{\n                'NUM_MODULES':4,\n                'NUM_BRANCHES':3,\n                'BLOCK': 'BASIC',\n                'NUM_BLOCKS':[4,4,4],\n                'NUM_CHANNELS':[18, 36, 72],\n                'FUSE_METHOD': 'SUM',\n                },\n            'STAGE4':{\n                'NUM_MODULES':3,\n                'NUM_BRANCHES':4,\n                'BLOCK': 'BASIC',\n                'NUM_BLOCKS':[4,4,4,4],\n                'NUM_CHANNELS':[18, 36, 72, 144],\n                'FUSE_METHOD': 'SUM',\n                },\n            }\n        } \n    },\n'hrnet_w32': {\n    'MODEL':{\n        'EXTRA':{\n            'STAGE1':{\n                'NUM_MODULES':1,\n                'NUM_BRANCHES':1,\n                'BLOCK': 'BOTTLENECK',\n                'NUM_BLOCKS':[4,],\n                'NUM_CHANNELS':[64,],\n                'FUSE_METHOD': 'SUM',\n                },\n       
     'STAGE2':{\n                'NUM_MODULES':1,\n                'NUM_BRANCHES':2,\n                'BLOCK': 'BASIC',\n                'NUM_BLOCKS':[4,4,],\n                'NUM_CHANNELS':[32, 64],\n                'FUSE_METHOD': 'SUM',\n                },\n            'STAGE3':{\n                'NUM_MODULES':4,\n                'NUM_BRANCHES':3,\n                'BLOCK': 'BASIC',\n                'NUM_BLOCKS':[4,4,4],\n                'NUM_CHANNELS':[32, 64, 128],\n                'FUSE_METHOD': 'SUM',\n                },\n            'STAGE4':{\n                'NUM_MODULES':3,\n                'NUM_BRANCHES':4,\n                'BLOCK': 'BASIC',\n                'NUM_BLOCKS':[4,4,4,4],\n                'NUM_CHANNELS':[32, 64, 128, 256],\n                'FUSE_METHOD': 'SUM',\n                },\n            }\n        } \n    }\n}\n\ndef get_cls_net(c, **kwargs):\n    cfg = config[c]\n    cfg['MODEL']['RETURN_STAGE'] = kwargs['return_stage']\n    model = HighResolutionNet(cfg, **kwargs)\n    model.init_weights(pretrained=kwargs['pretrained'])\n    return model\n\nif __name__ == '__main__':\n    net = get_cls_net('hrnet_w18', return_stage=2, pretrained='../weights/hrnetv2_w18_imagenet.pth')\n    pdb.set_trace()\n"
  },
  {
    "path": "unitrack/model/model.py",
    "content": "import pdb\nimport os.path as osp\n\nimport torch\nimport torch.nn as nn\n\nfrom unitrack.model import resnet\nfrom unitrack.model import hrnet\nfrom unitrack.model import random_feat_generator\n\nclass AppearanceModel(nn.Module):\n    def __init__(self, args):\n        super(AppearanceModel, self).__init__()\n        self.args = args\n        \n        self.model = make_encoder(args).to(self.args.device)\n    def forward(self, x):\n        z = self.model(x)\n        return z\n\ndef partial_load(pretrained_dict, model, skip_keys=[], log=False):\n    model_dict = model.state_dict()\n    \n    # 1. filter out unnecessary keys\n    filtered_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and not any([sk in k for sk in skip_keys])}\n    skipped_keys = [k for k in pretrained_dict if k not in filtered_dict]\n    unload_keys = [k for k in model_dict if k not in pretrained_dict]\n    \n    # 2. overwrite entries in the existing state dict\n    model_dict.update(filtered_dict)\n\n    # 3. 
load the new state dict\n    model.load_state_dict(model_dict)\n\n    if log:\n        print('\\nSkipped keys: ', skipped_keys)\n        print('\\nLoading keys: ', filtered_dict.keys())\n        print('\\nUnLoaded keys: ', unload_keys)\n\ndef load_vince_model(path):\n    checkpoint = torch.load(path, map_location={'cuda:0': 'cpu'})\n    checkpoint = {k.replace('feature_extractor.module.model.', ''): checkpoint[k] for k in checkpoint if 'feature_extractor' in k}\n    return checkpoint\n\n\ndef load_uvc_model(ckpt_path):\n    net = resnet.resnet18()\n    net.avgpool, net.fc = None, None\n\n    ckpt = torch.load(ckpt_path, map_location='cpu')\n    state_dict = {k.replace('module.gray_encoder.', ''):v for k,v in ckpt['state_dict'].items() if 'gray_encoder' in k}\n    partial_load(state_dict, net)\n\n    return net\n\n\ndef load_tc_model(ckpt_path):\n    model_state = torch.load(ckpt_path, map_location='cpu')['state_dict']\n    \n    net = resnet.resnet50()\n    net_state = net.state_dict()\n\n    for k in [k for k in model_state.keys() if 'encoderVideo' in k]:\n        kk = k.replace('module.encoderVideo.', '')\n        tmp = model_state[k]\n        if net_state[kk].shape != model_state[k].shape and net_state[kk].dim() == 4 and model_state[k].dim() == 5:\n            tmp = model_state[k].squeeze(2)\n        net_state[kk][:] = tmp[:]\n        \n    net.load_state_dict(net_state)\n\n    return net\n\nclass From3D(nn.Module):\n    ''' Use a 2D convnet as a 3D convnet '''\n    def __init__(self, resnet):\n        super(From3D, self).__init__()\n        self.model = resnet\n    \n    def forward(self, x):\n        N, C, T, h, w = x.shape\n        xx = x.permute(0, 2, 1, 3, 4).contiguous().view(-1, C, h, w)\n        m = self.model(xx)\n\n        return m.view(N, T, *m.shape[-3:]).permute(0, 2, 1, 3, 4)\n\n\ndef make_encoder(args):\n    SSL_MODELS = ['byol', 'deepcluster-v2', 'infomin', 'insdis', 'moco-v1', 'moco-v2',\n            'pcl-v1', 'pcl-v2','pirl', 'sela-v2', 'swav', 
'simclr-v1', 'simclr-v2',\n            'pixpro', 'detco', 'barlowtwins']\n    model_type = args.model_type\n    if model_type == 'crw':\n        net = resnet.resnet18()\n        if osp.isfile(args.resume):\n            ckpt = torch.load(args.resume)\n            state = {}\n            for k, v in ckpt['model'].items():\n                if 'conv1.1.weight' in k or 'conv2.1.weight' in k:\n                    state[k.replace('.1.weight', '.weight')] = v\n                if 'encoder.model' in k:\n                    state[k.replace('encoder.model.', '')] = v\n                else:\n                    state[k] = v\n            partial_load(state, net, skip_keys=['head',])\n            del ckpt\n    elif model_type == 'random18':\n        net = resnet.resnet18(pretrained=False)\n    elif model_type == 'random50':\n        net = resnet.resnet50(pretrained=False)\n    elif model_type == 'imagenet18':\n        net = resnet.resnet18(pretrained=True)\n    elif model_type == 'imagenet50':\n        net = resnet.resnet50(pretrained=True)\n    elif model_type == 'imagenet101':\n        net = resnet.resnet101(pretrained=True)\n    elif model_type == 'imagenet_resnext50':\n        net = resnet.resnext50_32x4d(pretrained=True)\n    elif model_type == 'imagenet_resnext101':\n        net = resnet.resnext101_32x8d(pretrained=True)\n    elif model_type == 'mocov2':\n        net = resnet.resnet50(pretrained=False)\n        net_ckpt = torch.load(args.resume)\n        net_state = {k.replace('module.encoder_q.', ''):v for k,v in net_ckpt['state_dict'].items() \\\n                if 'module.encoder_q' in k}\n        partial_load(net_state, net)\n    elif model_type == 'uvc':\n        net = load_uvc_model(args.resume)\n    elif model_type == 'timecycle':\n        net = load_tc_model(args.resume)\n    elif model_type in SSL_MODELS:\n        net = resnet.resnet50(pretrained=False)\n        net_ckpt = torch.load(args.resume)\n        partial_load(net_ckpt, net)\n    elif 'hrnet' in 
model_type:\n        net = hrnet.get_cls_net(model_type, return_stage=args.return_stage, pretrained=args.resume)\n    elif model_type == 'random':\n        net = random_feat_generator.RandomFeatGenerator(args)\n    else:\n        raise ValueError('Invalid model_type.')\n    if hasattr(net, 'modify'):\n        net.modify(remove_layers=args.remove_layers)\n\n    if 'Conv2d' in str(net) and not args.infer2D:\n        net = From3D(net)\n    return net\n"
  },
  {
    "path": "unitrack/model/random_feat_generator.py",
    "content": "###################################################################\n# File Name: random_feat_generator.py\n# Author: Zhongdao Wang\n# mail: wcd17@mails.tsinghua.edu.cn\n# Created Time: Mon May 10 16:13:46 2021\n###################################################################\n\nfrom __future__ import print_function\nfrom __future__ import division\nfrom __future__ import absolute_import\n\nimport torch\nimport torch.nn as nn\n\nclass RandomFeatGenerator(nn.Module):\n    # Returns uniform random features shaped like an encoder output: `dim`\n    # channels at 1/`down_factor` of the input resolution. Only the input's\n    # shape and device are used; its values are ignored.\n    def __init__(self, args):\n        super(RandomFeatGenerator, self).__init__()\n        self.df = args.down_factor\n        self.dim = args.dim\n        self.dummy = nn.Linear(2,3)\n    def forward(self, x):\n        if len(x.shape) == 4:\n            N,C,H,W = x.shape\n        elif len(x.shape) == 5:\n            N,C,T,H,W = x.shape\n        else:\n            raise ValueError('Expected a 4D (N,C,H,W) or 5D (N,C,T,H,W) input, '\n                             'got shape {}'.format(tuple(x.shape)))\n        c, h, w = self.dim, round(H/self.df), round(W/self.df)\n\n        # Allocate on the input's device rather than hard-coding .cuda()\n        if len(x.shape) == 4:\n            feat = torch.rand(N, c, h, w, device=x.device)\n        elif len(x.shape) == 5:\n            feat = torch.rand(N, c, T, h, w, device=x.device)\n        return feat\n\n    def __str__(self):\n        return ''\n"
  },
  {
    "path": "unitrack/model/resnet.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\ntry:\n    from torch.hub import load_state_dict_from_url\nexcept ImportError:\n    from torch.utils.model_zoo import load_url as load_state_dict_from_url\n\nimport torchvision.models.resnet as torch_resnet\nfrom torchvision.models.resnet import BasicBlock, Bottleneck\n\nmodel_urls = {'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',\n    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',\n    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',\n    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',\n    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',\n    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',\n    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',\n    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',\n    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',\n}\n\nclass ResNet(torch_resnet.ResNet):\n    def __init__(self, *args, **kwargs):\n        super(ResNet, self).__init__(*args, **kwargs)\n\n    def modify(self, remove_layers=[], padding=''):\n        # Set stride of layer3 and layer 4 to 1 (from 2)\n        filter_layers = lambda x: [l for l in x if getattr(self, l) is not None]\n        for layer in filter_layers(['layer3', 'layer4']):\n            for m in getattr(self, layer).modules():\n                if isinstance(m, torch.nn.Conv2d):\n                    m.stride = tuple(1 for _ in m.stride)\n        # Set padding (zeros or reflect, doesn't change much; \n        # zeros requires lower temperature)\n        if padding != '' and padding != 'no':\n            for m in self.modules():\n                if isinstance(m, torch.nn.Conv2d) and sum(m.padding) > 0:\n                    m.padding_mode = 
padding\n        elif padding == 'no':\n            for m in self.modules():\n                if isinstance(m, torch.nn.Conv2d) and sum(m.padding) > 0:\n                    m.padding = (0,0)\n\n        # Remove extraneous layers. Build a new list so that neither the\n        # caller's list nor the shared default argument gets mutated.\n        remove_layers = list(remove_layers) + ['fc', 'avgpool']\n        for layer in filter_layers(remove_layers):\n            setattr(self, layer, None)\n\n    def forward(self, x):\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = self.relu(x)\n        x = x if self.maxpool is None else self.maxpool(x)\n\n        x = self.layer1(x)\n        # If layer2 was removed, 2x2 average pooling keeps its stride-2 downsampling\n        x = F.avg_pool2d(x,(2,2)) if self.layer2 is None else self.layer2(x)\n        x = x if self.layer3 is None else self.layer3(x)\n        x = x if self.layer4 is None else self.layer4(x)\n\n        return x\n\n\ndef _resnet(arch, block, layers, pretrained, progress, **kwargs):\n    model = ResNet(block, layers, **kwargs)\n    if pretrained:\n        state_dict = load_state_dict_from_url(model_urls[arch],\n                                              progress=progress)\n        model.load_state_dict(state_dict)\n    return model\n\ndef resnet18(pretrained=False, progress=True, **kwargs):\n    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,\n                   **kwargs)\n\ndef resnet50(pretrained=False, progress=True, **kwargs) -> ResNet:\n    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,\n                   **kwargs)\n\ndef resnet101(pretrained=False, progress=True, **kwargs):\n    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,\n                   **kwargs)\n\ndef resnet152(pretrained=False, progress=True, **kwargs):\n    r\"\"\"ResNet-152 model from\n    `\"Deep Residual Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download 
to stderr\n    \"\"\"\n    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,\n                   **kwargs)\n\n\ndef resnext50_32x4d(pretrained=False, progress=True, **kwargs):\n    r\"\"\"ResNeXt-50 32x4d model from\n    `\"Aggregated Residual Transformation for Deep Neural Networks\" <https://arxiv.org/pdf/1611.05431.pdf>`_\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    kwargs['groups'] = 32\n    kwargs['width_per_group'] = 4\n    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],\n                   pretrained, progress, **kwargs)\n\n\ndef resnext101_32x8d(pretrained=False, progress=True, **kwargs):\n    r\"\"\"ResNeXt-101 32x8d model from\n    `\"Aggregated Residual Transformation for Deep Neural Networks\" <https://arxiv.org/pdf/1611.05431.pdf>`_\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    kwargs['groups'] = 32\n    kwargs['width_per_group'] = 8\n    return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],\n                   pretrained, progress, **kwargs)\n\n\ndef wide_resnet50_2(pretrained=False, progress=True, **kwargs):\n    r\"\"\"Wide ResNet-50-2 model from\n    `\"Wide Residual Networks\" <https://arxiv.org/pdf/1605.07146.pdf>`_\n\n    The model is the same as ResNet except for the bottleneck number of channels\n    which is twice larger in every block. The number of channels in outer 1x1\n    convolutions is the same, e.g. 
last block in ResNet-50 has 2048-512-2048\n    channels, and in Wide ResNet-50-2 has 2048-1024-2048.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    kwargs['width_per_group'] = 64 * 2\n    return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],\n                   pretrained, progress, **kwargs)\n\n\ndef wide_resnet101_2(pretrained=False, progress=True, **kwargs):\n    r\"\"\"Wide ResNet-101-2 model from\n    `\"Wide Residual Networks\" <https://arxiv.org/pdf/1605.07146.pdf>`_\n\n    The model is the same as ResNet except for the bottleneck number of channels\n    which is twice larger in every block. The number of channels in outer 1x1\n    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048\n    channels, and in Wide ResNet-50-2 has 2048-1024-2048.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    kwargs['width_per_group'] = 64 * 2\n    return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],\n                   pretrained, progress, **kwargs)\n"
  },
  {
    "path": "unitrack/multitracker.py",
    "content": "import os\nimport pdb\nimport cv2\nimport time\nimport itertools\nimport os.path as osp\nfrom collections import deque\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom torchvision import ops\n\nfrom unitrack.model import AppearanceModel, partial_load\nfrom unitrack.utils.log import logger\nfrom unitrack.core.association import matching\nfrom unitrack.core.propagation import propagate\nfrom unitrack.core.motion.kalman_filter import KalmanFilter\n\nfrom unitrack.utils.box import *\nfrom unitrack.utils.mask import *\nfrom .basetrack import *\n\n\nclass AssociationTracker(object):\n    def __init__(self, opt):\n        self.opt = opt\n        self.tracked_stracks = []  # type: list[STrack]\n        self.lost_stracks = []     # type: list[STrack]\n        self.removed_stracks = []  # type: list[STrack]\n\n        self.frame_id = 0\n        self.det_thresh = opt.conf_thres\n        self.buffer_size = opt.track_buffer\n        self.max_time_lost = self.buffer_size\n\n        self.kalman_filter = KalmanFilter()\n\n        self.app_model = AppearanceModel(opt).to(opt.device)\n        self.app_model.eval()\n        \n        if not self.opt.asso_with_motion:\n            self.opt.motion_lambda = 1\n            self.opt.motion_gated = False\n        \n    def extract_emb(self, img, obs):\n        raise NotImplementedError\n\n    def prepare_obs(self, img, img0, obs, embs=None):\n        raise NotImplementedError\n\n    def update(self, img, img0, obs, embs=None):\n        torch.cuda.empty_cache()\n        self.frame_id += 1\n        activated_stracks = []\n        refind_stracks = []\n        lost_stracks = []\n        removed_stracks = []\n \n        t1 = time.time()\n        detections = self.prepare_obs(img, img0, obs, embs=None)\n\n        ''' Add newly detected tracklets to tracked_stracks'''\n        unconfirmed = []\n        tracked_stracks = []  # type: list[STrack]\n        for track in self.tracked_stracks:\n            if 
not track.is_activated:\n                unconfirmed.append(track)\n            else:\n                tracked_stracks.append(track)\n\n        ''' Step 2: First association, with embedding'''\n        tracks = joint_stracks(tracked_stracks, self.lost_stracks)\n        dists, recons_ftrk = matching.reconsdot_distance(tracks, detections)\n        if self.opt.use_kalman: \n            # Predict the current location with KF\n            STrack.multi_predict(tracks)\n            dists = matching.fuse_motion(self.kalman_filter, dists, tracks, detections, \n                    lambda_=self.opt.motion_lambda, gate=self.opt.motion_gated)\n        if obs.shape[1] == 6:\n            dists = matching.category_gate(dists, tracks, detections)\n        matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.7)\n\n        for itracked, idet in matches:\n            track = tracks[itracked]\n            det = detections[idet]\n            if track.state == TrackState.Tracked:\n                track.update(detections[idet], self.frame_id)\n                activated_stracks.append(track)\n            else:\n                track.re_activate(det, self.frame_id, new_id=False)\n                refind_stracks.append(track)\n        \n        if self.opt.use_kalman:\n            '''(optional) Step 3: Second association, with IOU'''\n            tracks = [tracks[i] for i in u_track if tracks[i].state==TrackState.Tracked]\n            detections = [detections[i] for i in u_detection]\n            dists = matching.iou_distance(tracks, detections)\n            matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.5)\n            \n            for itracked, idet in matches:\n                track = tracks[itracked]\n                det = detections[idet]\n                if track.state == TrackState.Tracked:\n                    track.update(det, self.frame_id)\n                    activated_stracks.append(track)\n                else:\n                 
   track.re_activate(det, self.frame_id, new_id=False)\n                    refind_stracks.append(track)\n\n            '''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''\n            detections = [detections[i] for i in u_detection]\n            dists = matching.iou_distance(unconfirmed, detections)\n            matches, u_unconfirmed, u_detection = matching.linear_assignment(\n                    dists, thresh=self.opt.confirm_iou_thres)\n            for itracked, idet in matches:\n                unconfirmed[itracked].update(detections[idet], self.frame_id)\n                activated_stracks.append(unconfirmed[itracked])\n            for it in u_unconfirmed:\n                track = unconfirmed[it]\n                track.mark_removed()\n                removed_stracks.append(track)\n\n        for it in u_track:\n            track = tracks[it]\n            if not track.state == TrackState.Lost:\n                track.mark_lost()\n                lost_stracks.append(track)\n\n        \"\"\" Step 4: Init new stracks\"\"\"\n        for inew in u_detection:\n            track = detections[inew]\n            if track.score < self.det_thresh:\n                continue\n            track.activate(self.kalman_filter, self.frame_id)\n            activated_stracks.append(track)\n\n        \"\"\" Step 5: Update state\"\"\"\n        for track in self.lost_stracks:\n            if self.frame_id - track.end_frame > self.max_time_lost:\n                track.mark_removed()\n                removed_stracks.append(track)\n\n        self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked]\n        self.tracked_stracks = joint_stracks(self.tracked_stracks, activated_stracks)\n        self.tracked_stracks = joint_stracks(self.tracked_stracks, refind_stracks)\n        self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks)\n        self.lost_stracks.extend(lost_stracks)\n        self.lost_stracks = 
sub_stracks(self.lost_stracks, self.removed_stracks)\n        self.removed_stracks.extend(removed_stracks)\n        self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(\n                self.tracked_stracks, self.lost_stracks, ioudist=self.opt.dup_iou_thres)\n\n        # output only the activated tracks\n        output_stracks = [track for track in self.tracked_stracks if track.is_activated]\n\n        return output_stracks\n\n    def reset_all(self):\n        self.tracked_stracks = []  # type: list[STrack]\n        self.lost_stracks = []  # type: list[STrack]\n        self.removed_stracks = []  # type: list[STrack]\n        self.frame_id = 0"
  },
  {
    "path": "unitrack/utils/__init__.py",
    "content": "from collections import defaultdict, deque\nimport datetime\nimport time\nimport torch\n\nimport errno\nimport os\nimport pdb\nimport sys\n\nfrom . import visualize\nfrom . import box\nfrom . import meter\nfrom . import log\n\nimport numpy as np\nfrom torch import nn\nfrom torch.nn import functional as F\n\n\ndef to_numpy(tensor):\n    if torch.is_tensor(tensor):\n        return tensor.cpu().numpy()\n    elif type(tensor).__module__ != 'numpy':\n        raise ValueError(\"Cannot convert {} to numpy array\"\n                         .format(type(tensor)))\n    return tensor\n\ndef to_torch(ndarray):\n    if type(ndarray).__module__ == 'numpy':\n        return torch.from_numpy(ndarray)\n    elif not torch.is_tensor(ndarray):\n        raise ValueError(\"Cannot convert {} to torch tensor\"\n                         .format(type(ndarray)))\n    return ndarray\n\ndef im_to_numpy(img):\n    img = to_numpy(img)\n    img = np.transpose(img, (1, 2, 0)) # H*W*C\n    return img\n\ndef im_to_torch(img):\n    img = np.transpose(img, (2, 0, 1)) # C*H*W\n    img = to_torch(img).float()\n    return img\n"
  },
  {
    "path": "unitrack/utils/box.py",
    "content": "###################################################################\n# File Name: box.py\n# Author: Zhongdao Wang\n# mail: wcd17@mails.tsinghua.edu.cn\n# Created Time: Wed Dec 23 16:27:15 2020\n###################################################################\n\nimport torch\nimport torchvision\nimport numpy as np\n\n\ndef xyxy2xywh(x):\n    # Convert bounding box format from [x1, y1, x2, y2] to [x, y, w, h]\n    y = x.clone() if isinstance(x, torch.Tensor) else x.copy()\n    y[:, 0] = (x[:, 0] + x[:, 2]) / 2\n    y[:, 1] = (x[:, 1] + x[:, 3]) / 2\n    y[:, 2] = x[:, 2] - x[:, 0]\n    y[:, 3] = x[:, 3] - x[:, 1]\n    return y\n\n\ndef xywh2xyxy(x):\n    # Convert bounding box format from [x, y, w, h] to [x1, y1, x2, y2]\n    y = x.clone() if isinstance(x, torch.Tensor) else x.copy()\n    y[:, 0] = (x[:, 0] - x[:, 2] / 2)\n    y[:, 1] = (x[:, 1] - x[:, 3] / 2)\n    y[:, 2] = (x[:, 0] + x[:, 2] / 2)\n    y[:, 3] = (x[:, 1] + x[:, 3] / 2)\n    return y\n\n\ndef tlwh2xyxy(x):\n    # Convert bounding box format from [x1, y1, w, h] (top-left corner, width, height) to [x1, y1, x2, y2]\n    y = x.clone() if isinstance(x, torch.Tensor) else x.copy()\n    y[:, 2] = (x[:, 0] + x[:, 2])\n    y[:, 3] = (x[:, 1] + x[:, 3])\n    return y\n\n\ndef tlwh_to_xywh(tlwh):\n    ret = np.asarray(tlwh).copy()\n    ret[:2] += ret[2:] / 2\n    return ret\n\n\ndef tlwh_to_xyah(tlwh):\n    \"\"\"Convert bounding box to format `(center x, center y, aspect ratio,\n    height)`, where the aspect ratio is `width / height`.\n    \"\"\"\n    ret = np.asarray(tlwh).copy()\n    ret[:2] += ret[2:] / 2\n    ret[2] /= (ret[3] + 1e-6)\n    return ret\n\n\ndef tlbr_to_tlwh(tlbr):\n    ret = np.asarray(tlbr).copy()\n    ret[2:] -= ret[:2]\n    return ret\n\n\ndef tlwh_to_tlbr(tlwh):\n    ret = np.asarray(tlwh).copy()\n    ret[2:] += ret[:2]\n    return ret\n\n\ndef scale_box(scale, coords):\n    c = coords.clone()\n    c[:, [0, 2]] = coords[:, [0, 2]] * scale[0]\n    c[:, [1, 3]] = coords[:, [1, 3]] * scale[1]\n    return c\n\n\ndef scale_box_letterbox_size(img_size, coords, img0_shape):\n    gain_w = float(img_size[0]) / img0_shape[1]  # gain  = old / new\n    gain_h = float(img_size[1]) / img0_shape[0]\n    gain = min(gain_w, gain_h)\n    pad_x = (img_size[0] - img0_shape[1] * gain) / 2  # width padding\n    pad_y = (img_size[1] - img0_shape[0] * gain) / 2  # height padding\n    coords[:, 0:4] *= gain\n    coords[:, [0, 2]] += pad_x\n    coords[:, [1, 3]] += pad_y\n    return coords\n\n\ndef scale_box_input_size(img_size, coords, img0_shape):\n    # Rescale x1, y1, x2, y2 from the letterboxed input size back to the original image size\n    gain_w = float(img_size[0]) / img0_shape[1]  # gain  = old / new\n    gain_h = float(img_size[1]) / img0_shape[0]\n    gain = min(gain_w, gain_h)\n    pad_x = (img_size[0] - img0_shape[1] * gain) / 2  # width padding\n    pad_y = (img_size[1] - img0_shape[0] * gain) / 2  # height padding\n    coords[:, [0, 2]] -= pad_x\n    coords[:, [1, 3]] -= pad_y\n    coords[:, 0:4] /= gain\n    return coords\n\n\ndef clip_boxes(boxes, im_shape):\n    \"\"\"\n    Clip boxes to image boundaries.\n    \"\"\"\n    boxes = np.asarray(boxes)\n    if boxes.shape[0] == 0:\n        return boxes\n    boxes = np.copy(boxes)\n    # x1 >= 0\n    boxes[:, 0::4] = np.maximum(np.minimum(boxes[:, 0::4], im_shape[1] - 1), 0)\n    # y1 >= 0\n    boxes[:, 1::4] = np.maximum(np.minimum(boxes[:, 1::4], im_shape[0] - 1), 0)\n    # x2 < im_shape[1]\n    boxes[:, 2::4] = np.maximum(np.minimum(boxes[:, 2::4], im_shape[1] - 1), 0)\n    # y2 < im_shape[0]\n    boxes[:, 3::4] = np.maximum(np.minimum(boxes[:, 3::4], im_shape[0] - 1), 0)\n    return boxes\n\n\ndef clip_box(bbox, im_shape):\n    h, w = im_shape[:2]\n    bbox = np.copy(bbox)\n    bbox[0] = max(min(bbox[0], w - 1), 0)\n    bbox[1] = max(min(bbox[1], h - 1), 0)\n    bbox[2] = max(min(bbox[2], w - 1), 0)\n    bbox[3] = max(min(bbox[3], h - 1), 0)\n\n    return bbox\n\n\ndef int_box(box):\n    box = np.asarray(box, dtype=float)\n    box = np.round(box)\n    return np.asarray(box, dtype=int)\n\n\ndef remove_duplicated_box(boxes, iou_th=0.5):\n    if isinstance(boxes, np.ndarray):\n        boxes = torch.from_numpy(boxes)\n    jac = torchvision.ops.box_iou(boxes, boxes).float()\n    jac -= torch.eye(jac.shape[0])\n    keep = np.ones(len(boxes)) == 1\n    for i, b in enumerate(boxes):\n        if b[0] == -1 and b[1] == -1 and b[2] == 10 and b[3] == 10:\n            keep[i] = False\n    for r, row in enumerate(jac):\n        if keep[r]:\n            discard = torch.where(row > iou_th)\n            keep[discard] = False\n    return np.where(keep)[0]\n\n\ndef skltn2box(skltn):\n    dskltn = dict()\n    for s in skltn:\n        dskltn[s['id'][0]] = (int(s['x'][0]), int(s['y'][0]))\n    if len(dskltn) == 0:\n        return np.array(\n                [-1, -1, np.random.randint(1, 40), np.random.randint(1, 70)])\n\n    xmin = np.min([dskltn[k][0] for k in dskltn])\n    xmax = np.max([dskltn[k][0] for k in dskltn])\n    ymin = np.min([dskltn[k][1] for k in dskltn])\n    ymax = np.max([dskltn[k][1] for k in dskltn])\n    if xmin == xmax:\n        xmax += 10\n    if ymin == ymax:\n        ymax += 10\n    return np.array([xmin, ymin, xmax, ymax])\n"
  },
  {
    "path": "unitrack/utils/io.py",
    "content": "import os\nimport os.path as osp\nfrom typing import Dict\nimport numpy as np\n\nfrom utils.log import logger\n\ndef mkdir_if_missing(d):\n    if not osp.exists(d):\n        os.makedirs(d)\n\ndef write_mots_results(filename, results, data_type='mot'):\n    if not filename:\n        return\n    path = os.path.dirname(filename)\n    if not os.path.exists(path):\n        os.makedirs(path)\n\n    if data_type in ('mot',):\n        save_format = '{frame} {id} {cid} {imh} {imw} {rle}\\n'\n    else:\n        raise ValueError(data_type)\n\n    with open(filename, 'w') as f:\n        for frame_id, tlwhs, rles, track_ids in results:\n            for rle, track_id in zip(rles, track_ids):\n                if track_id < 0:\n                    continue\n                rle_str = rle['counts']\n                imh, imw = rle['size']\n                line = save_format.format(frame=frame_id, id=track_id+2000, cid=2, imh=imh, imw=imw, rle=rle_str)\n                f.write(line)\n    logger.info('Save results to {}'.format(filename))\n\ndef write_mot_results(filename, results, data_type='mot'):\n    if not filename:\n        return\n    path = os.path.dirname(filename)\n    if not os.path.exists(path):\n        os.makedirs(path)\n\n    if data_type in ('mot', 'mcmot', 'lab'):\n        save_format = '{frame},{id},{x1},{y1},{w},{h},1,-1,-1,-1\\n'\n    elif data_type == 'kitti':\n        # results carry no per-track score, so a constant score of 1 is written\n        save_format = '{frame} {id} pedestrian -1 -1 -10 {x1} {y1} {x2} {y2} -1 -1 -1 -1000 -1000 -1000 -10 1\\n'\n    else:\n        raise ValueError(data_type)\n\n    with open(filename, 'w') as f:\n        for frame_id, tlwhs, track_ids in results:\n            if data_type == 'kitti':\n                frame_id -= 1\n            for tlwh, track_id in zip(tlwhs, track_ids):\n                if track_id < 0:\n                    continue\n                x1, y1, w, h = tlwh\n                x2, y2 = x1 + w, y1 + h\n                line = save_format.format(frame=frame_id, id=track_id, 
x1=x1, y1=y1, x2=x2, y2=y2, w=w, h=h)\n                f.write(line)\n    logger.info('Save results to {}'.format(filename))\n\n\ndef read_mot_results(filename, data_type='mot', is_gt=False, is_ignore=False):\n    if data_type in ('mot', 'lab'):\n        read_fun = _read_mot_results\n    else:\n        raise ValueError('Unknown data type: {}'.format(data_type))\n\n    return read_fun(filename, is_gt, is_ignore)\n\n\n\"\"\"\nlabels={'ped', ...\t\t\t% 1\n'person_on_vhcl', ...\t% 2\n'car', ...\t\t\t\t% 3\n'bicycle', ...\t\t\t% 4\n'mbike', ...\t\t\t% 5\n'non_mot_vhcl', ...\t\t% 6\n'static_person', ...\t% 7\n'distractor', ...\t\t% 8\n'occluder', ...\t\t\t% 9\n'occluder_on_grnd', ...\t\t%10\n'occluder_full', ...\t\t% 11\n'reflection', ...\t\t% 12\n'crowd' ...\t\t\t% 13\n};\n\"\"\"\n\n\ndef _read_mot_results(filename, is_gt, is_ignore):\n    valid_labels = {1}\n    ignore_labels = {2, 7, 8, 12}\n    results_dict = dict()\n    if os.path.isfile(filename):\n        with open(filename, 'r') as f:\n            for line in f.readlines():\n                linelist = line.split(',')\n                if len(linelist) < 7:\n                    continue\n                fid = int(linelist[0])\n                if fid < 1:\n                    continue\n                results_dict.setdefault(fid, list())\n\n                if is_gt:\n                    if 'MOT16-' in filename or 'MOT17-' in filename:\n                        label = int(float(linelist[7]))\n                        mark = int(float(linelist[6]))\n                        if mark == 0 or label not in valid_labels:\n                            continue\n                    score = 1\n                elif is_ignore:\n                    if 'MOT16-' in filename or 'MOT17-' in filename:\n                        label = int(float(linelist[7]))\n                        vis_ratio = float(linelist[8])\n                        if label not in ignore_labels and vis_ratio >= 0:\n                            continue\n             
       else:\n                        continue\n                    score = 1\n                else:\n                    score = float(linelist[6])\n\n                tlwh = tuple(map(float, linelist[2:6]))\n                target_id = int(linelist[1])\n\n                results_dict[fid].append((tlwh, target_id, score))\n\n    return results_dict\n\n\ndef unzip_objs(objs):\n    if len(objs) > 0:\n        tlwhs, ids, scores = zip(*objs)\n    else:\n        tlwhs, ids, scores = [], [], []\n    tlwhs = np.asarray(tlwhs, dtype=float).reshape(-1, 4)\n\n    return tlwhs, ids, scores\n"
  },
  {
    "path": "unitrack/utils/log.py",
    "content": "import logging\n\n\ndef get_logger(name='root'):\n    formatter = logging.Formatter(\n        # fmt='%(asctime)s [%(levelname)s]: %(filename)s(%(funcName)s:%(lineno)s) >> %(message)s')\n        fmt='%(asctime)s [%(levelname)s]: %(message)s', datefmt='%Y-%m-%d %H:%M:%S')\n\n    handler = logging.StreamHandler()\n    handler.setFormatter(formatter)\n\n    logger = logging.getLogger(name)\n    logger.setLevel(logging.DEBUG)\n    logger.addHandler(handler)\n    return logger\n\n\nlogger = get_logger('root')\n"
  },
  {
    "path": "unitrack/utils/mask.py",
    "content": "###################################################################\n# File Name: mask.py\n# Author: Zhongdao Wang\n# mail: wcd17@mails.tsinghua.edu.cn\n# Created Time: Tue Feb  9 10:05:47 2021\n###################################################################\n\nfrom __future__ import print_function\nfrom __future__ import division\nfrom __future__ import absolute_import\n\nimport cv2\nimport torch\nimport numpy as np\nimport pycocotools.mask as mask_utils\n\n\ndef coords2bbox(coords, extend=2):\n    \"\"\"\n    INPUTS:\n     - coords: coordinates of pixels in the next frame\n    \"\"\"\n    center = torch.mean(coords, dim=0) # b * 2\n    center = center.view(1,2)\n    center_repeat = center.repeat(coords.size(0),1)\n\n    dis_x = torch.sqrt(torch.pow(coords[:,0] - center_repeat[:,0], 2))\n    dis_x = max(torch.mean(dis_x, dim=0).detach(),1)\n    dis_y = torch.sqrt(torch.pow(coords[:,1] - center_repeat[:,1], 2))\n    dis_y = max(torch.mean(dis_y, dim=0).detach(),1)\n\n    left = center[:,0] - dis_x*extend\n    right = center[:,0] + dis_x*extend\n    top = center[:,1] - dis_y*extend\n    bottom = center[:,1] + dis_y*extend\n\n    return (top.item(), left.item(), bottom.item(), right.item())\n\n\ndef coords2bbox_all(coords):\n    left = coords[:, 0].min().item()\n    top = coords[:, 1].min().item()\n    right = coords[:, 0].max().item()\n    bottom = coords[:, 1].max().item()\n    return top, left, bottom, right\n\n\ndef coords2bboxTensor(coords, extend=2):\n    \"\"\"\n    INPUTS:\n     - coords: coordinates of pixels in the next frame\n    \"\"\"\n    center = torch.mean(coords, dim=0) # b * 2\n    center = center.view(1,2)\n    center_repeat = center.repeat(coords.size(0),1)\n\n    dis_x = torch.sqrt(torch.pow(coords[:,0] - center_repeat[:,0], 2))\n    dis_x = max(torch.mean(dis_x, dim=0).detach(),1)\n    dis_y = torch.sqrt(torch.pow(coords[:,1] - center_repeat[:,1], 2))\n    dis_y = max(torch.mean(dis_y, dim=0).detach(),1)\n\n    left = 
center[:,0] - dis_x*extend\n    right = center[:,0] + dis_x*extend\n    top = center[:,1] - dis_y*extend\n    bottom = center[:,1] + dis_y*extend\n\n    return torch.Tensor([top.item(), left.item(), bottom.item(), right.item()]).to(coords.device)\n\ndef mask2box(masks):\n    boxes = []\n    for mask in masks:\n        m = mask[0].nonzero().float()\n        if m.numel() > 0:\n            box = coords2bbox(m, extend=2)\n        else:\n            box = (-1,-1,10,10)\n        boxes.append(box)\n    return np.asarray(boxes)\n\ndef tensor_mask2box(masks):\n    boxes = []\n    for mask in masks:\n        m = mask.nonzero().float()\n        if m.numel() > 0:\n            # box = coords2bbox(m, extend=2)\n            box = coords2bbox_all(m)\n        else:\n            box = (-1,-1,10,10)\n        boxes.append(box)\n    return np.asarray(boxes)\n\ndef batch_mask2boxlist(masks):\n    \"\"\"\n    Args:\n        masks: Tensor b,n,h,w\n\n    Returns: List[List[box]]\n\n    \"\"\"\n    batch_bbox = []\n    for i, b_masks in enumerate(masks):\n        boxes = []\n        for mask in b_masks:\n            m = mask.nonzero().float()\n            if m.numel() > 0:\n                box = coords2bboxTensor(m, extend=2)\n            else:\n                box = torch.Tensor([0,0,0,0]).to(m.device)\n            boxes.append(box.unsqueeze(0))\n        boxes_t = torch.cat(boxes, 0)\n        batch_bbox.append(boxes_t)\n\n    return batch_bbox\n\n\ndef bboxlist2roi(bbox_list):\n    \"\"\"Convert a list of bboxes to roi format.\n\n    Args:\n        bbox_list (list[Tensor]): a list of bboxes corresponding to a batch\n            of images.\n\n    Returns:\n        Tensor: shape (n, 5), [batch_ind, x1, y1, x2, y2]\n    \"\"\"\n    rois_list = []\n    for img_id, bboxes in enumerate(bbox_list):\n        if bboxes.size(0) > 0:\n            img_inds = bboxes.new_full((bboxes.size(0), 1), img_id)\n            rois = torch.cat([img_inds, bboxes[:, :4]], dim=-1)\n        else:\n            rois = 
bboxes.new_zeros((0, 5))\n        rois_list.append(rois)\n    rois = torch.cat(rois_list, 0)\n    return rois\n\ndef bbox2roi(bbox_list):\n    \"\"\"Convert a list of bboxes to roi format.\n\n    Args:\n        bbox_list (list[Tensor]): a list of bboxes corresponding to a batch\n            of images.\n\n    Returns:\n        Tensor: shape (n, 5), [batch_ind, x1, y1, x2, y2]\n    \"\"\"\n    rois_list = []\n    for img_id, bboxes in enumerate(bbox_list):\n        if bboxes.size(0) > 0:\n            img_inds = bboxes.new_full((bboxes.size(0), 1), img_id)\n            rois = torch.cat([img_inds, bboxes[:, :4]], dim=-1)\n        else:\n            rois = bboxes.new_zeros((0, 5))\n        rois_list.append(rois)\n    rois = torch.cat(rois_list, 0)\n    return rois\n\ndef temp_interp_mask(maskseq, T):\n    '''\n    maskseq: list of elements (RLE_mask, timestamp)\n    return list of RLE_mask, length of list is T\n    '''\n    size = maskseq[0][0]['size']\n    blank_mask = np.asfortranarray(np.zeros(size).astype(np.uint8))\n    blank_mask = mask_utils.encode(blank_mask)\n    blank_mask['counts'] = blank_mask['counts'].decode('ascii')\n    ret = [blank_mask,] * T\n    for m, t in maskseq:\n        ret[t] = m\n    return ret\n\ndef mask_seq_jac(sa, sb):\n    j = np.zeros((len(sa), len(sb)))\n    for ia, a in enumerate(sa):\n        for ib, b in enumerate(sb):\n            ious = [mask_utils.iou([at], [bt], [False,]) for (at, bt) in zip(a,b)]\n            tiou = np.mean(ious)\n            j[ia, ib] = tiou\n    return j\n        \n\ndef skltn2mask(skltn, size):\n    h, w = size\n    mask = np.zeros((h,w))\n    \n    dskltn = dict()\n    for s in skltn:\n        dskltn[s['id'][0]] = (int(s['x'][0]), int(s['y'][0]))\n    if len(dskltn)==0:\n        return mask\n    trunk_polygon = list()\n    for k in np.array([3,4,10,13,9])-1:\n        p = dskltn.get(k, None)\n        if not p is None:\n            trunk_polygon.append(p)\n    trunk_polygon = np.asarray(trunk_polygon, 
'int32')\n    if len(trunk_polygon) > 2:\n        cv2.fillConvexPoly(mask, trunk_polygon, 1)\n\n    xmin = np.min([dskltn[k][0] for k in dskltn])\n    xmax = np.max([dskltn[k][0] for k in dskltn])\n    ymin = np.min([dskltn[k][1] for k in dskltn])\n    ymax = np.max([dskltn[k][1] for k in dskltn])\n    line_width = np.max([int(np.max([xmax-xmin, ymax-ymin, 0])/20),8])\n\n\n    skeleton = [[10, 11], [11, 12], [9,8], \n                [8,7], [10, 13], [9, 13], \n                [13, 15], [10,4], [4,5], \n                [5,6], [9,3], [3,2], [2,1]]\n    \n\n    for sk in skeleton:\n        st = dskltn.get(sk[0]-1, None)\n        ed = dskltn.get(sk[1]-1, None)\n        if st is None or ed is None:\n            continue\n        cv2.line(mask, st, ed, color=1, thickness=line_width)\n    \n    #dmask = cv2.resize(mask, (w//8, h//8), interpolation=cv2.INTER_NEAREST)\n    #pdb.set_trace()\n    \n    return mask\n\n\ndef pts2array(pts):\n    arr = np.zeros((15,3))\n    for s in pts:\n        arr[s['id'][0]][0] = int(s['x'][0])\n        arr[s['id'][0]][1] = int(s['y'][0])\n        arr[s['id'][0]][2] = s['score'][0]\n    return arr\n"
  },
  {
    "path": "unitrack/utils/meter.py",
    "content": "###################################################################\n# File Name: meter.py\n# Author: Zhongdao Wang\n# mail: wcd17@mails.tsinghua.edu.cn\n# Created Time: Wed Dec 23 16:35:34 2020\n###################################################################\n\nfrom __future__ import print_function\nfrom __future__ import division\nfrom __future__ import absolute_import\nimport time\n\n\nclass Timer(object):\n    \"\"\"A simple timer.\"\"\"\n    def __init__(self):\n        self.total_time = 0.\n        self.calls = 0\n        self.start_time = 0.\n        self.diff = 0.\n        self.average_time = 0.\n\n        self.duration = 0.\n\n    def tic(self):\n        # using time.time instead of time.clock because time.clock\n        # does not normalize for multithreading\n        self.start_time = time.time()\n\n    def toc(self, average=True):\n        self.diff = time.time() - self.start_time\n        self.total_time += self.diff\n        self.calls += 1\n        self.average_time = self.total_time / self.calls\n        if average:\n            self.duration = self.average_time\n        else:\n            self.duration = self.diff\n        return self.duration\n\n    def clear(self):\n        self.total_time = 0.\n        self.calls = 0\n        self.start_time = 0.\n        self.diff = 0.\n        self.average_time = 0.\n        self.duration = 0.\n\n"
  },
  {
    "path": "unitrack/utils/palette.py",
    "content": "palette_str = '''0 0 0\n128 0 0\n0 128 0\n128 128 0\n0 0 128\n128 0 128\n0 128 128\n128 128 128\n64 0 0\n191 0 0\n64 128 0\n191 128 0\n64 0 128\n191 0 128\n64 128 128\n191 128 128\n0 64 0\n128 64 0\n0 191 0\n128 191 0\n0 64 128\n128 64 128\n22 22 22\n23 23 23\n24 24 24\n25 25 25\n26 26 26\n27 27 27\n28 28 28\n29 29 29\n30 30 30\n31 31 31\n32 32 32\n33 33 33\n34 34 34\n35 35 35\n36 36 36\n37 37 37\n38 38 38\n39 39 39\n40 40 40\n41 41 41\n42 42 42\n43 43 43\n44 44 44\n45 45 45\n46 46 46\n47 47 47\n48 48 48\n49 49 49\n50 50 50\n51 51 51\n52 52 52\n53 53 53\n54 54 54\n55 55 55\n56 56 56\n57 57 57\n58 58 58\n59 59 59\n60 60 60\n61 61 61\n62 62 62\n63 63 63\n64 64 64\n65 65 65\n66 66 66\n67 67 67\n68 68 68\n69 69 69\n70 70 70\n71 71 71\n72 72 72\n73 73 73\n74 74 74\n75 75 75\n76 76 76\n77 77 77\n78 78 78\n79 79 79\n80 80 80\n81 81 81\n82 82 82\n83 83 83\n84 84 84\n85 85 85\n86 86 86\n87 87 87\n88 88 88\n89 89 89\n90 90 90\n91 91 91\n92 92 92\n93 93 93\n94 94 94\n95 95 95\n96 96 96\n97 97 97\n98 98 98\n99 99 99\n100 100 100\n101 101 101\n102 102 102\n103 103 103\n104 104 104\n105 105 105\n106 106 106\n107 107 107\n108 108 108\n109 109 109\n110 110 110\n111 111 111\n112 112 112\n113 113 113\n114 114 114\n115 115 115\n116 116 116\n117 117 117\n118 118 118\n119 119 119\n120 120 120\n121 121 121\n122 122 122\n123 123 123\n124 124 124\n125 125 125\n126 126 126\n127 127 127\n128 128 128\n129 129 129\n130 130 130\n131 131 131\n132 132 132\n133 133 133\n134 134 134\n135 135 135\n136 136 136\n137 137 137\n138 138 138\n139 139 139\n140 140 140\n141 141 141\n142 142 142\n143 143 143\n144 144 144\n145 145 145\n146 146 146\n147 147 147\n148 148 148\n149 149 149\n150 150 150\n151 151 151\n152 152 152\n153 153 153\n154 154 154\n155 155 155\n156 156 156\n157 157 157\n158 158 158\n159 159 159\n160 160 160\n161 161 161\n162 162 162\n163 163 163\n164 164 164\n165 165 165\n166 166 166\n167 167 167\n168 168 168\n169 169 169\n170 170 170\n171 171 171\n172 172 172\n173 173 
173\n174 174 174\n175 175 175\n176 176 176\n177 177 177\n178 178 178\n179 179 179\n180 180 180\n181 181 181\n182 182 182\n183 183 183\n184 184 184\n185 185 185\n186 186 186\n187 187 187\n188 188 188\n189 189 189\n190 190 190\n191 191 191\n192 192 192\n193 193 193\n194 194 194\n195 195 195\n196 196 196\n197 197 197\n198 198 198\n199 199 199\n200 200 200\n201 201 201\n202 202 202\n203 203 203\n204 204 204\n205 205 205\n206 206 206\n207 207 207\n208 208 208\n209 209 209\n210 210 210\n211 211 211\n212 212 212\n213 213 213\n214 214 214\n215 215 215\n216 216 216\n217 217 217\n218 218 218\n219 219 219\n220 220 220\n221 221 221\n222 222 222\n223 223 223\n224 224 224\n225 225 225\n226 226 226\n227 227 227\n228 228 228\n229 229 229\n230 230 230\n231 231 231\n232 232 232\n233 233 233\n234 234 234\n235 235 235\n236 236 236\n237 237 237\n238 238 238\n239 239 239\n240 240 240\n241 241 241\n242 242 242\n243 243 243\n244 244 244\n245 245 245\n246 246 246\n247 247 247\n248 248 248\n249 249 249\n250 250 250\n251 251 251\n252 252 252\n253 253 253\n254 254 254\n255 255 255'''\nimport numpy as np\ntensor = np.array([[int(x) for x in line.split()] for line in palette_str.split('\\n')])\n"
  },
  {
    "path": "unitrack/utils/visualize.py",
    "content": "\nimport cv2\nimport numpy as np\nimport imageio as io\nfrom matplotlib import cm\n\nimport time\nimport PIL\n\nimport pycocotools.mask as mask_utils\nfrom . import palette\n\n\ndef dump_predictions(pred, lbl_set, img, prefix):\n    '''\n    Save:\n        1. Predicted labels for evaluation\n        2. Label heatmaps for visualization\n    '''\n    lbl_set = palette.tensor.astype(np.uint8)\n    sz = img.shape[:-1]\n\n    # Upsample predicted soft label maps\n    # pred_dist = pred.copy()\n    pred_dist = cv2.resize(pred, sz[::-1])[:]\n    \n    # Argmax to get the hard label for index\n    pred_lbl = np.argmax(pred_dist, axis=-1)\n    pred_lbl = np.array(lbl_set, dtype=np.int32)[pred_lbl]      \n    mask = np.float32(pred_lbl.sum(2) > 0)[:,:,None]\n    alpha = 0.5\n    img_with_label = mask * (np.float32(img) * alpha + \\\n            np.float32(pred_lbl) * (1-alpha)) + (1-mask) * np.float32(img)\n\n    # Visualize label distribution for object 1 (debugging/analysis)\n    pred_soft = pred_dist[..., 1]\n    pred_soft = cv2.resize(pred_soft, (img.shape[1], img.shape[0]), \n            interpolation=cv2.INTER_NEAREST)\n    pred_soft = cm.jet(pred_soft)[..., :3] * 255.0\n    img_with_heatmap1 =  np.float32(img) * 0.5 + np.float32(pred_soft) * 0.5\n\n    # Save blend image for visualization\n    io.imwrite('%s_blend.jpg' % prefix, np.uint8(img_with_label))\n\n    if prefix[-4] != '.':  # Super HACK-y\n        imname2 = prefix + '_mask.png'\n    else:\n        imname2 = prefix.replace('jpg','png')\n\n    # Save predicted labels for evaluation\n    io.imwrite(imname2, np.uint8(pred_lbl))\n\n    return img_with_label, pred_lbl, img_with_heatmap1\n\n\n\ndef make_gif(video, outname='/tmp/test.gif', sz=256):\n    if hasattr(video, 'shape'):\n        video = video.cpu()\n        if video.shape[0] == 3:\n            video = video.transpose(0, 1)\n\n        video = video.numpy().transpose(0, 2, 3, 1)\n        video = (video*255).astype(np.uint8)\n        \n    
video = [cv2.resize(vv, (sz, sz)) for vv in video]\n\n    if outname is None:\n        return np.stack(video)\n\n    io.mimsave(outname, video, duration = 0.2)\n\ndef get_color(idx):\n    idx = idx * 17\n    color = ((37 * idx) % 255, (17 * idx) % 255, (29 * idx) % 255)\n    return color\n\ndef plot_tracking(image, obs, obj_ids, scores=None, frame_id=0, fps=0.):\n    im = np.ascontiguousarray(np.copy(image))\n    im_h, im_w = im.shape[:2]\n\n    text_scale = max(1, image.shape[1] / 1600.)\n    text_thickness = 1 if text_scale > 1.1 else 1\n    line_thickness = max(1, int(image.shape[1] / 150.))\n    alpha = 0.4\n\n    for i, ob in enumerate(obs): \n        obj_id = int(obj_ids[i])\n        id_text = '{}'.format(int(obj_id))\n        _line_thickness = 1 if obj_id <= 0 else line_thickness\n        color = get_color(obj_id)\n        if len(ob) == 4:\n            x1, y1, w, h = ob\n            intbox = tuple(map(int, (x1, y1, x1 + w, y1 + h)))\n            cv2.rectangle(im, intbox[0:2], intbox[2:4], color=color, thickness=line_thickness)\n            cv2.putText(im, id_text, (intbox[0], intbox[1] + 30), cv2.FONT_HERSHEY_PLAIN, text_scale, (0, 0, 255),\n                        thickness=text_thickness)\n        elif isinstance(ob, dict):\n            mask = mask_utils.decode(ob)\n            mask = cv2.resize(mask, (im_w, im_h), interpolation=cv2.INTER_LINEAR)\n            mask = (mask > 0.5).astype(np.uint8)[:,:,None]\n            mask_color = mask * color\n            im = (1 - mask) * im + mask * (alpha*im + (1-alpha)*mask_color) \n        else:\n            raise ValueError('Observation format not supported.')\n    return im\n\n\ndef vis_pose(oriImg, points):\n\n    pa = np.zeros(15)\n    pa[2] = 0\n    pa[12] = 8\n    pa[8] = 4\n    pa[4] = 0\n    pa[11] = 7\n    pa[7] = 3\n    pa[3] = 0\n    pa[0] = 1\n    pa[14] = 10\n    pa[10] = 6\n    pa[6] = 1\n    pa[13] = 9\n    pa[9] = 5\n    pa[5] = 1\n\n    colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 
0], [170, 255, 0], [85, 255, 0], [0, 255, 0],\n              [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255],\n              [170,0,255],[255,0,255]]\n    canvas = oriImg\n    stickwidth = 4\n    x = points[0, :]\n    y = points[1, :]\n\n    for n in range(len(x)):\n        pair_id = int(pa[n])\n\n        x1 = int(x[pair_id])\n        y1 = int(y[pair_id])\n        x2 = int(x[n])\n        y2 = int(y[n])\n\n        if x1 >= 0 and y1 >= 0 and x2 >= 0 and y2 >= 0:\n            cv2.line(canvas, (x1, y1), (x2, y2), colors[n], 8)\n\n    return canvas\n\n\ndef draw_skeleton(aa, kp, color, show_skeleton_labels=False, dataset= \"PoseTrack\"):\n    if dataset == \"COCO\":\n        skeleton = [[16, 14], [14, 12], [17, 15], [15, 13], [12, 13], \n                [6, 12], [7, 13], [6, 7], [6, 8], [7, 9], [8, 10], \n                [9, 11], [2, 3], [1, 2], [1, 3], [2, 4], [3, 5], [4, 6], [5, 7]]\n        kp_names = ['nose', 'l_eye', 'r_eye', 'l_ear', 'r_ear', 'l_shoulder',\n                    'r_shoulder', 'l_elbow', 'r_elbow', 'l_wrist', 'r_wrist',\n                    'l_hip', 'r_hip', 'l_knee', 'r_knee', 'l_ankle', 'r_ankle']\n    elif dataset == \"PoseTrack\":\n        skeleton = [[10, 11], [11, 12], [9,8], [8,7],\n                    [10, 13], [9, 13], [13, 15], [10,4],\n                    [4,5], [5,6], [9,3], [3,2], [2,1]]\n        kp_names = ['right_ankle', 'right_knee', 'right_pelvis',\n                    'left_pelvis', 'left_knee', 'left_ankle',\n                    'right_wrist', 'right_elbow', 'right_shoulder',\n                    'left_shoulder', 'left_elbow', 'left_wrist',\n                    'upper_neck', 'nose', 'head']\n    for i, j in skeleton:\n        if kp[i-1][0] >= 0 and kp[i-1][1] >= 0 and kp[j-1][0] >= 0 and kp[j-1][1] >= 0 and \\\n            (len(kp[i-1]) <= 2 or (len(kp[i-1]) > 2 and  kp[i-1][2] > 0.1 and kp[j-1][2] > 0.1)):\n            st = (int(kp[i-1][0]), int(kp[i-1][1]))\n            ed = 
(int(kp[j-1][0]), int(kp[j-1][1]))\n            cv2.line(aa, st, ed,  color, max(1, int(aa.shape[1]/150.)))\n    for j in range(len(kp)):\n        if kp[j][0] >= 0 and kp[j][1] >= 0:\n            pt = (int(kp[j][0]), int(kp[j][1]))\n            if len(kp[j]) <= 2 or (len(kp[j]) > 2 and kp[j][2] > 1.1):\n                cv2.circle(aa, pt, 2, tuple((0,0,255)), 2)\n            elif len(kp[j]) <= 2 or (len(kp[j]) > 2 and kp[j][2] > 0.1):\n                cv2.circle(aa, pt, 2, tuple((255,0,0)), 2)\n\n            if show_skeleton_labels and (len(kp[j]) <= 2 or (len(kp[j]) > 2 and kp[j][2] > 0.1)):\n                cv2.putText(aa, kp_names[j], pt, cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 0))\n"
  }
]