[
  {
    "path": ".gitignore",
    "content": ".eggs/*\n.vscode/*\nwork_dirs/*\nwork_dir/*\npretrained/*\nckpt/*\nrunai_dataset/*\n*/__pycache__\n*.pyc\ndata/*\ndata\noutput/*\n.idea/*\npose_anything.egg-info/*"
  },
  {
    "path": "LICENSE",
    "content": "Copyright (c) 2022 SenseTime. All Rights Reserved.\n\n                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. 
For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. 
Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. (Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright 2020 MMClassification Authors.\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": ":new: *Please check out [EdgeCape](https://github.com/orhir/EdgeCape), our more recent effort in the same line of work.*\n<br /> <br />\n\n# A Graph-Based Approach for Category-Agnostic Pose Estimation [ECCV 2024]\n<a href=\"https://orhir.github.io/pose-anything/\"><img src=\"https://img.shields.io/static/v1?label=Project&message=Website&color=blue\"></a>\n<a href=\"https://arxiv.org/abs/2311.17891\"><img src=\"https://img.shields.io/badge/arXiv-2311.17891-b31b1b.svg\"></a>\n<a href=\"https://www.apache.org/licenses/LICENSE-2.0.txt\"><img src=\"https://img.shields.io/badge/License-Apache-yellow\"></a>\n[![Hugging Face](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/orhir/PoseAnything)\n[![Open in OpenXLab](https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg)](https://openxlab.org.cn/apps/detail/orhir/Pose-Anything)\n[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/pose-anything-a-graph-based-approach-for/2d-pose-estimation-on-mp-100)](https://paperswithcode.com/sota/2d-pose-estimation-on-mp-100?p=pose-anything-a-graph-based-approach-for)\n\nBy [Or Hirschorn](https://scholar.google.co.il/citations?user=GgFuT_QAAAAJ&hl=iw&oi=ao) and [Shai Avidan](https://scholar.google.co.il/citations?hl=iw&user=hpItE1QAAAAJ)\n\nThis repo is the official implementation of \"[A Graph-Based Approach for Category-Agnostic Pose Estimation](https://arxiv.org/pdf/2311.17891.pdf)\".\n\n<p align=\"center\">\n<img src=\"Pose_Anything_Teaser.png\" width=\"384\">\n</p>\n\n## 🔔 News\n- **`11 July 2024`** Our paper will be presented at **ECCV 2024**.\n- **`10 July 2024`** Uploaded new annotations - fix a small bug of DeepFashion skeletons.\n- **`2 Feburary 2024`** Uploaded new weights - smaller models with stronger performance.\n- **`20 December 2023`** Demo is online on [Huggingface](https://huggingface.co/spaces/orhir/PoseAnything) and [OpenXLab](https://openxlab.org.cn/apps/detail/orhir/Pose-Anything).\n- **`7 December 2023`** Official code release.\n\n## Introduction\n\nWe present a novel approach to CAPE that leverages the inherent geometrical relations between keypoints through a newly designed Graph Transformer Decoder. By capturing and incorporating this crucial structural information, our method enhances the accuracy of keypoint localization, marking a significant departure from conventional CAPE techniques that treat keypoints as isolated entities.\n\n## Citation\nIf you find this useful, please cite this work as follows:\n```bibtex\n@misc{hirschorn2023pose,\n      title={A Graph-Based Approach for Category-Agnostic Pose Estimation}, \n      author={Or Hirschorn and Shai Avidan},\n      year={2024},\n      eprint={2311.17891},\n      archivePrefix={arXiv},\n      primaryClass={cs.CV},\n      url={https://arxiv.org/abs/2311.17891}, \n}\n```\n\n## Getting Started\n\n### Docker [Recommended]\nWe provide a docker image for easy use.\nYou can simply pull the docker image from docker hub, containing all the required libraries and packages:\n\n```\ndocker pull orhir/pose_anything\ndocker run --name pose_anything -v {DATA_DIR}:/workspace/PoseAnything/PoseAnything/data/mp100 -it orhir/pose_anything /bin/bash\n```\n### Conda Environment\nWe train and evaluate our model on Python 3.8 and Pytorch 2.0.1 with CUDA 12.1. \n\nPlease first install pytorch and torchvision following official documentation Pytorch. 
\nThen, follow [MMPose](https://mmpose.readthedocs.io/en/latest/installation.html) to install the following packages:\n```\nmmcv-full==1.6.2\nmmpose==0.29.0\n```\nHaving installed these packages, run:\n```\npython setup.py develop\n```\n\n## Demo on Custom Images\n<i>TRY IT NOW ON:</i> <a href=\"https://huggingface.co/spaces/orhir/PoseAnything\">Hugging Face</a> / <a href=\"https://openxlab.org.cn/apps/detail/orhir/Pose-Anything\">OpenXLab</a>\n\nWe provide demo code for testing our model on custom images.\n\n### Gradio Demo\nFirst, install Gradio:\n```\npip install gradio==3.44.0\n```\nThen, download the [pretrained model](https://drive.google.com/file/d/1RT1Q8AMEa1kj6k9ZqrtWIKyuR4Jn4Pqc/view?usp=drive_link) and run:\n```\npython app.py --checkpoint [path_to_pretrained_ckpt]\n```\n### Terminal Demo\nDownload the [pretrained model](https://drive.google.com/file/d/1RT1Q8AMEa1kj6k9ZqrtWIKyuR4Jn4Pqc/view?usp=drive_link) and run:\n\n```\npython demo.py --support [path_to_support_image] --query [path_to_query_image] --config configs/demo_b.py --checkpoint [path_to_pretrained_ckpt]\n```\n***Note:*** The demo code supports any config with a suitable checkpoint file. More pre-trained models can be found in the evaluation section.\n\n\n## Updated MP-100 Dataset\nPlease follow the [official guide](https://github.com/luminxu/Pose-for-Everything/blob/main/mp100/README.md) to prepare the MP-100 dataset for training and evaluation, and organize the data structure properly.\n\n**We provide an updated annotation file, which includes skeleton definitions, in the following [link](https://drive.google.com/drive/folders/1uRyGB-P5Tc_6TmAZ6RnOi0SWjGq9b28T?usp=sharing).**\n\n**Please note:**\n\nThe current version of the MP-100 dataset includes some discrepancies and filename errors:\n1. The dataset referred to as DeepFashion is actually the DeepFashion2 dataset; the link in the official repo is wrong. Use this [repo](https://github.com/switchablenorms/DeepFashion2/tree/master) instead.\n2. We provide a script to fix the CarFusion filename errors, which can be run by:\n```\npython tools/fix_carfusion.py [path_to_CarFusion_dataset] [path_to_mp100_annotation]\n```\n\n## Training\n\n### Backbone Options\nTo use the pre-trained Swin-Transformer backbone from our paper, we provide the weights, taken from this [repo](https://github.com/microsoft/Swin-Transformer/blob/main/MODELHUB.md), in the following [link](https://drive.google.com/drive/folders/1-q4mSxlNAUwDlevc3Hm5Ij0l_2OGkrcg?usp=sharing).\nThese should be placed in the `./pretrained` folder.\n\nWe also support DINO and ResNet backbones. To use them, simply change the `pretrained` field in the config file to `dinov2`, `dino`, or `resnet`; the corresponding pretrained weights are then loaded automatically from the official repos.
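\n\nAs a concrete illustration, a minimal sketch of that change (the `model` field names follow the configs in this repo; the elided settings stay as released):\n```\n# sketch only - minimal change relative to the released configs\nmodel = dict(\n    type='PoseAnythingModel',\n    pretrained='dinov2',  # or 'dino' / 'resnet'; replaces a Swin checkpoint\n    # path such as 'pretrained/swinv2_small_1k_500k.pth'\n    # ... encoder_config, keypoint_head, etc. unchanged from the released config\n)\n```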
\n\n### Training\nTo train the model, run:\n```\npython train.py --config [path_to_config_file] --work-dir [path_to_work_dir]\n```\n\n## Evaluation and Pretrained Models\nYou can download the pretrained checkpoints from the following [link](https://drive.google.com/drive/folders/1RmrqzE3g0qYRD5xn54-aXEzrIkdYXpEW?usp=sharing).\n\nHere we provide the evaluation results of our pretrained models on the MP-100 dataset, along with the config files and checkpoints:\n\n### 1-Shot Models\n| Setting | split 1 | split 2 | split 3 | split 4 | split 5 |\n|:-------:|:-------:|:-------:|:-------:|:-------:|:-------:|\n| Tiny | 91.19 | 87.81 | 85.68 | 85.87 | 85.61 |\n| | [link](https://drive.google.com/file/d/1GubmkVkqybs-eD4hiRkgBzkUVGE_rIFX/view?usp=drive_link) / [config](configs/1shots/graph_split1_config.py) | [link](https://drive.google.com/file/d/1EEekDF3xV_wJOVk7sCQWUA8ygUKzEm2l/view?usp=drive_link) / [config](configs/1shots/graph_split2_config.py) | [link](https://drive.google.com/file/d/1FuwpNBdPI3mfSovta2fDGKoqJynEXPZQ/view?usp=drive_link) / [config](configs/1shots/graph_split3_config.py) | [link](https://drive.google.com/file/d/1_SSqSANuZlbC0utzIfzvZihAW9clefcR/view?usp=drive_link) / [config](configs/1shots/graph_split4_config.py) | [link](https://drive.google.com/file/d/1nUHr07W5F55u-FKQEPFq_CECgWZOKKLF/view?usp=drive_link) / [config](configs/1shots/graph_split5_config.py) |\n| Small | 94.73 | 89.79 | 90.69 | 88.09 | 90.11 |\n| | [link](https://drive.google.com/file/d/1RT1Q8AMEa1kj6k9ZqrtWIKyuR4Jn4Pqc/view?usp=drive_link) / [config](configs/1shot-swin/graph_split1_config.py) | [link](https://drive.google.com/file/d/1BT5b8MlnkflcdhTFiBROIQR3HccLsPQd/view?usp=drive_link) / [config](configs/1shot-swin/graph_split2_config.py) | [link](https://drive.google.com/file/d/1Z64cw_1CSDGObabSAWKnMK0BA_bqDHxn/view?usp=drive_link) / [config](configs/1shot-swin/graph_split3_config.py) | [link](https://drive.google.com/file/d/1vf82S8LAjIzpuBcbEoDCa26cR8DqNriy/view?usp=drive_link) / [config](configs/1shot-swin/graph_split4_config.py) | [link](https://drive.google.com/file/d/14FNx0JNbkS2CvXQMiuMU_kMZKFGO2rDV/view?usp=drive_link) / [config](configs/1shot-swin/graph_split5_config.py) |\n\n### 5-Shot Models\n| Setting | split 1 | split 2 | split 3 | split 4 | split 5 |\n|:-------:|:-------:|:-------:|:-------:|:-------:|:-------:|\n| Tiny | 94.24 | 91.32 | 90.15 | 90.37 | 89.73 |\n| | [link](https://drive.google.com/file/d/1PeMuwv5YwiF3UCE5oN01Qchu5K3BaQ9L/view?usp=drive_link) / [config](configs/5shots/graph_split1_config.py) | [link](https://drive.google.com/file/d/1enIapPU1D8lZOET7q_qEjnhC1HFy3jWK/view?usp=drive_link) / [config](configs/5shots/graph_split2_config.py) | [link](https://drive.google.com/file/d/1MTeZ9Ba-ucLuqX0KBoLbBD5PaEct7VUp/view?usp=drive_link) / [config](configs/5shots/graph_split3_config.py) | [link](https://drive.google.com/file/d/1U2N7DI2F0v7NTnPCEEAgx-WKeBZNAFoa/view?usp=drive_link) / [config](configs/5shots/graph_split4_config.py) | [link](https://drive.google.com/file/d/1wapJDgtBWtmz61JNY7ktsFyvckRKiR2C/view?usp=drive_link) / [config](configs/5shots/graph_split5_config.py) |\n| Small | 96.67 | 91.48 | 92.62 | 90.95 | 92.41 |\n| | [link](https://drive.google.com/file/d/1p5rnA0MhmndSKEbyXMk49QXvNE03QV2p/view?usp=drive_link) / [config](configs/5shot-swin/graph_split1_config.py) | [link](https://drive.google.com/file/d/1Q3KNyUW_Gp3JytYxUPhkvXFiDYF6Hv8w/view?usp=drive_link) / [config](configs/5shot-swin/graph_split2_config.py) | [link](https://drive.google.com/file/d/1gWgTk720fSdAf_ze1FkfXTW0t7k-69dV/view?usp=drive_link) / [config](configs/5shot-swin/graph_split3_config.py) | [link](https://drive.google.com/file/d/1LuaRQ8a6AUPrkr7l5j2W6Fe_QbgASkwY/view?usp=drive_link) / [config](configs/5shot-swin/graph_split4_config.py) | [link](https://drive.google.com/file/d/1z--MAOPCwMG_GQXru9h2EStbnIvtHv1L/view?usp=drive_link) / [config](configs/5shot-swin/graph_split5_config.py) |\n\n### Evaluation\nEvaluation on a single GPU takes approximately 30 minutes.\n\nTo evaluate a pretrained model, run:\n```\npython test.py [path_to_config_file] [path_to_pretrained_ckpt]\n```\n## Acknowledgement\n\nOur code is based on code from:\n - [MMPose](https://github.com/open-mmlab/mmpose)\n - [CapeFormer](https://github.com/flyinglynx/CapeFormer)\n\n\n## License\nThis project is released under the Apache 2.0 license.\n"
  },
  {
    "path": "app.py",
    "content": "import argparse\nimport random\n\nimport gradio as gr\nimport matplotlib\nimport numpy as np\nimport torch\nfrom PIL import ImageDraw, Image\nfrom matplotlib import pyplot as plt\nfrom mmcv import Config\nfrom mmcv.runner import load_checkpoint\nfrom mmpose.core import wrap_fp16_model\nfrom mmpose.models import build_posenet\nfrom torchvision import transforms\n\nfrom demo import Resize_Pad\nfrom models import *\n\n# Copyright (c) OpenMMLab. All rights reserved.\n# os.system('python -m pip install timm')\n# os.system('python -m pip install Openmim')\n# os.system('python -m mim install mmengine')\n# os.system('python -m mim install \"mmcv-full==1.6.2\"')\n# os.system('python -m mim install \"mmpose==0.29.0\"')\n# os.system('python -m mim install \"gradio==3.44.0\"')\n# os.system('python setup.py develop')\n\nmatplotlib.use('agg')\ncheckpoint_path = ''\n\n\ndef plot_results(support_img, query_img, support_kp, support_w, query_kp,\n                 query_w, skeleton,\n                 initial_proposals, prediction, radius=6):\n    h, w, c = support_img.shape\n    prediction = prediction[-1].cpu().numpy() * h\n    query_img = (query_img - np.min(query_img)) / (\n            np.max(query_img) - np.min(query_img))\n    for id, (img, w, keypoint) in enumerate(zip([query_img],\n                                                [query_w],\n                                                [prediction])):\n        f, axes = plt.subplots()\n        plt.imshow(img)\n        for k in range(keypoint.shape[0]):\n            if w[k] > 0:\n                kp = keypoint[k, :2]\n                c = (1, 0, 0, 0.75) if w[k] == 1 else (0, 0, 1, 0.6)\n                patch = plt.Circle(kp, radius, color=c)\n                axes.add_patch(patch)\n                axes.text(kp[0], kp[1], k)\n                plt.draw()\n        for l, limb in enumerate(skeleton):\n            kp = keypoint[:, :2]\n            if l > len(COLORS) - 1:\n                c = [x / 255 for x in random.sample(range(0, 255), 3)]\n            else:\n                c = [x / 255 for x in COLORS[l]]\n            if w[limb[0]] > 0 and w[limb[1]] > 0:\n                patch = plt.Line2D([kp[limb[0], 0], kp[limb[1], 0]],\n                                   [kp[limb[0], 1], kp[limb[1], 1]],\n                                   linewidth=6, color=c, alpha=0.6)\n                axes.add_artist(patch)\n        plt.axis('off')  # command for hiding the axis.\n        plt.subplots_adjust(0, 0, 1, 1, 0, 0)\n        return plt\n\n\nCOLORS = [\n    [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0],\n    [85, 255, 0], [0, 255, 0], [0, 255, 85], [0, 255, 170], [0, 255, 255],\n    [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], [170, 0, 255],\n    [255, 0, 255], [255, 0, 170], [255, 0, 85], [255, 0, 0]\n]\n\n\ndef process(query_img, state,\n            cfg_path='configs/demo_b.py'):\n    cfg = Config.fromfile(cfg_path)\n    width, height, _ = state['original_support_image'].shape\n    kp_src_np = np.array(state['kp_src']).copy().astype(np.float32)\n    kp_src_np[:, 0] = kp_src_np[:, 0] / (\n            width // 4) * cfg.model.encoder_config.img_size\n    kp_src_np[:, 1] = kp_src_np[:, 1] / (\n            height // 4) * cfg.model.encoder_config.img_size\n    kp_src_np = np.flip(kp_src_np, 1).copy()\n    kp_src_tensor = torch.tensor(kp_src_np).float()\n    preprocess = transforms.Compose([\n        transforms.ToTensor(),\n        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),\n        
Resize_Pad(cfg.model.encoder_config.img_size,\n                   cfg.model.encoder_config.img_size)])\n\n    if len(state['skeleton']) == 0:\n        state['skeleton'] = [(0, 0)]\n\n    support_img = preprocess(state['original_support_image']).flip(0)[None]\n    np_query = np.array(query_img)[:, :, ::-1].copy()\n    q_img = preprocess(np_query).flip(0)[None]\n    # Create heatmap from keypoints\n    genHeatMap = TopDownGenerateTargetFewShot()\n    data_cfg = cfg.data_cfg\n    data_cfg['image_size'] = np.array([cfg.model.encoder_config.img_size,\n                                       cfg.model.encoder_config.img_size])\n    data_cfg['joint_weights'] = None\n    data_cfg['use_different_joint_weights'] = False\n    kp_src_3d = torch.cat(\n        (kp_src_tensor, torch.zeros(kp_src_tensor.shape[0], 1)), dim=-1)\n    kp_src_3d_weight = torch.cat(\n        (torch.ones_like(kp_src_tensor),\n         torch.zeros(kp_src_tensor.shape[0], 1)), dim=-1)\n    target_s, target_weight_s = genHeatMap._msra_generate_target(data_cfg,\n                                                                 kp_src_3d,\n                                                                 kp_src_3d_weight,\n                                                                 sigma=1)\n    target_s = torch.tensor(target_s).float()[None]\n    target_weight_s = torch.ones_like(\n        torch.tensor(target_weight_s).float()[None])\n\n    data = {\n        'img_s': [support_img],\n        'img_q': q_img,\n        'target_s': [target_s],\n        'target_weight_s': [target_weight_s],\n        'target_q': None,\n        'target_weight_q': None,\n        'return_loss': False,\n        'img_metas': [{'sample_skeleton': [state['skeleton']],\n                       'query_skeleton': state['skeleton'],\n                       'sample_joints_3d': [kp_src_3d],\n                       'query_joints_3d': kp_src_3d,\n                       'sample_center': [kp_src_tensor.mean(dim=0)],\n                       'query_center': kp_src_tensor.mean(dim=0),\n                       'sample_scale': [\n                           kp_src_tensor.max(dim=0)[0] -\n                           kp_src_tensor.min(dim=0)[0]],\n                       'query_scale': kp_src_tensor.max(dim=0)[0] -\n                                      kp_src_tensor.min(dim=0)[0],\n                       'sample_rotation': [0],\n                       'query_rotation': 0,\n                       'sample_bbox_score': [1],\n                       'query_bbox_score': 1,\n                       'query_image_file': '',\n                       'sample_image_file': [''],\n                       }]\n    }\n    # Load model\n    model = build_posenet(cfg.model)\n    fp16_cfg = cfg.get('fp16', None)\n    if fp16_cfg is not None:\n        wrap_fp16_model(model)\n    load_checkpoint(model, checkpoint_path, map_location='cpu')\n    model.eval()\n    with torch.no_grad():\n        outputs = model(**data)\n    # visualize results\n    vis_s_weight = target_weight_s[0]\n    vis_q_weight = target_weight_s[0]\n    vis_s_image = support_img[0].detach().cpu().numpy().transpose(1, 2, 0)\n    vis_q_image = q_img[0].detach().cpu().numpy().transpose(1, 2, 0)\n    support_kp = kp_src_3d\n    out = plot_results(vis_s_image,\n                       vis_q_image,\n                       support_kp,\n                       vis_s_weight,\n                       None,\n                       vis_q_weight,\n                       state['skeleton'],\n                       None,\n                       
torch.tensor(outputs['points']).squeeze(0),\n                       )\n    return out, state\n\n\ndef update_examples(support_img, posed_support, query_img, state, r=0.015, width=0.02):\n    state['color_idx'] = 0\n    state['original_support_image'] = np.array(support_img)[:, :, ::-1].copy()\n    support_img, posed_support, _ = set_query(support_img, state, example=True)\n    w, h = support_img.size\n    draw_pose = ImageDraw.Draw(support_img)\n    draw_limb = ImageDraw.Draw(posed_support)\n    r = int(r * w)\n    width = int(width * w)\n    for pixel in state['kp_src']:\n        leftUpPoint = (pixel[1] - r, pixel[0] - r)\n        rightDownPoint = (pixel[1] + r, pixel[0] + r)\n        twoPointList = [leftUpPoint, rightDownPoint]\n        draw_pose.ellipse(twoPointList, fill=(255, 0, 0, 255))\n        draw_limb.ellipse(twoPointList, fill=(255, 0, 0, 255))\n    for limb in state['skeleton']:\n        point_a = state['kp_src'][limb[0]][::-1]\n        point_b = state['kp_src'][limb[1]][::-1]\n        if state['color_idx'] < len(COLORS):\n            c = COLORS[state['color_idx']]\n            state['color_idx'] += 1\n        else:\n            c = random.choices(range(256), k=3)\n        draw_limb.line([point_a, point_b], fill=tuple(c), width=width)\n    return support_img, posed_support, query_img, state\n\n\ndef get_select_coords(kp_support,\n                      limb_support,\n                      state,\n                      evt: gr.SelectData,\n                      r=0.015):\n    pixels_in_queue = set()\n    pixels_in_queue.add((evt.index[1], evt.index[0]))\n    while len(pixels_in_queue) > 0:\n        pixel = pixels_in_queue.pop()\n        if pixel[0] is not None and pixel[1] is not None and pixel not in \\\n                state['kp_src']:\n            state['kp_src'].append(pixel)\n        else:\n            continue\n        if limb_support is None:\n            canvas_limb = kp_support\n        else:\n            canvas_limb = limb_support\n        canvas_kp = kp_support\n        w, h = canvas_kp.size\n        draw_pose = ImageDraw.Draw(canvas_kp)\n        draw_limb = ImageDraw.Draw(canvas_limb)\n        r = int(r * w)\n        leftUpPoint = (pixel[1] - r, pixel[0] - r)\n        rightDownPoint = (pixel[1] + r, pixel[0] + r)\n        twoPointList = [leftUpPoint, rightDownPoint]\n        draw_pose.ellipse(twoPointList, fill=(255, 0, 0, 255))\n        draw_limb.ellipse(twoPointList, fill=(255, 0, 0, 255))\n    return canvas_kp, canvas_limb, state\n\n\ndef get_limbs(kp_support,\n              state,\n              evt: gr.SelectData,\n              r=0.02, width=0.02):\n    curr_pixel = (evt.index[1], evt.index[0])\n    pixels_in_queue = set()\n    pixels_in_queue.add((evt.index[1], evt.index[0]))\n    canvas_kp = kp_support\n    w, h = canvas_kp.size\n    r = int(r * w)\n    width = int(width * w)\n    while len(pixels_in_queue) > 0 and curr_pixel != state['prev_clicked']:\n        pixel = pixels_in_queue.pop()\n        state['prev_clicked'] = pixel\n        closest_point = min(state['kp_src'],\n                            key=lambda p: (p[0] - pixel[0]) ** 2 +\n                                          (p[1] - pixel[1]) ** 2)\n        closest_point_index = state['kp_src'].index(closest_point)\n        draw_limb = ImageDraw.Draw(canvas_kp)\n        if state['color_idx'] < len(COLORS):\n            c = COLORS[state['color_idx']]\n        else:\n            c = random.choices(range(256), k=3)\n        leftUpPoint = (closest_point[1] - r, closest_point[0] - r)\n        rightDownPoint 
= (closest_point[1] + r, closest_point[0] + r)\n        twoPointList = [leftUpPoint, rightDownPoint]\n        draw_limb.ellipse(twoPointList, fill=tuple(c))\n        if state['count'] == 0:\n            state['prev_pt'] = closest_point[1], closest_point[0]\n            state['prev_pt_idx'] = closest_point_index\n            state['count'] = state['count'] + 1\n        else:\n            if state['prev_pt_idx'] != closest_point_index:\n                # Create Line and add Limb\n                draw_limb.line(\n                    [state['prev_pt'], (closest_point[1], closest_point[0])],\n                    fill=tuple(c),\n                    width=width)\n                state['skeleton'].append(\n                    (state['prev_pt_idx'], closest_point_index))\n                state['color_idx'] = state['color_idx'] + 1\n            else:\n                draw_limb.ellipse(twoPointList, fill=(255, 0, 0, 255))\n            state['count'] = 0\n    return canvas_kp, state\n\n\ndef set_query(support_img, state, example=False):\n    if not example:\n        state['skeleton'].clear()\n        state['kp_src'].clear()\n    state['original_support_image'] = np.array(support_img)[:, :, ::-1].copy()\n    width, height = support_img.size\n    support_img = support_img.resize((width // 4, width // 4),\n                                     Image.Resampling.LANCZOS)\n    return support_img, support_img, state\n\n\nwith gr.Blocks() as demo:\n    state = gr.State({\n        'kp_src': [],\n        'skeleton': [],\n        'count': 0,\n        'color_idx': 0,\n        'prev_pt': None,\n        'prev_pt_idx': None,\n        'prev_clicked': None,\n        'original_support_image': None,\n    })\n\n    gr.Markdown('''\n    # Pose Anything Demo\n    We present a novel approach to category agnostic pose estimation that \n    leverages the inherent geometrical relations between keypoints through a \n    newly designed Graph Transformer Decoder. By capturing and incorporating \n    this crucial structural information, our method enhances the accuracy of \n    keypoint localization, marking a significant departure from conventional \n    CAPE techniques that treat keypoints as isolated entities.\n    ### [Paper](https://arxiv.org/abs/2311.17891) | [Official Repo](https://github.com/orhir/PoseAnything) \n    ## Instructions\n    1. Upload an image of the object you want to pose on the **left** image.\n    2. Click on the **left** image to mark keypoints.\n    3. Click on the keypoints on the **right** image to mark limbs.\n    4. Upload an image of the object you want to pose to the query image (\n    **bottom**).\n    5. 
Click **Evaluate** to pose the query image.\n    ''')\n    with gr.Row():\n        support_img = gr.Image(label=\"Support Image\",\n                               type=\"pil\",\n                               info='Click to mark keypoints').style(\n            height=400, width=400)\n        posed_support = gr.Image(label=\"Posed Support Image\",\n                                 type=\"pil\",\n                                 interactive=False).style(height=400,\n                                                          width=400)\n    with gr.Row():\n        query_img = gr.Image(label=\"Query Image\",\n                             type=\"pil\").style(height=400, width=400)\n    with gr.Row():\n        eval_btn = gr.Button(value=\"Evaluate\")\n    with gr.Row():\n        output_img = gr.Plot(label=\"Output Image\", height=400, width=400)\n    with gr.Row():\n        gr.Markdown(\"## Examples\")\n    with gr.Row():\n        gr.Examples(\n            examples=[\n                ['examples/dog2.png',\n                 'examples/dog2.png',\n                 'examples/dog1.png',\n                 {'kp_src': [(50, 58), (51, 78), (66, 57), (118, 79),\n                             (154, 79), (217, 74), (218, 103), (156, 104),\n                             (152, 151), (215, 162), (213, 191),\n                             (152, 174), (108, 171)],\n                  'skeleton': [(0, 1), (1, 2), (0, 2), (3, 4), (4, 5),\n                               (3, 7), (7, 6), (3, 12), (12, 8), (8, 9),\n                               (12, 11), (11, 10)], 'count': 0,\n                  'color_idx': 0, 'prev_pt': (174, 152),\n                  'prev_pt_idx': 11, 'prev_clicked': (207, 186),\n                  'original_support_image': None,\n                  }\n                 ],\n                ['examples/sofa1.jpg',\n                 'examples/sofa1.jpg',\n                 'examples/sofa2.jpg',\n                 {\n                     'kp_src': [(82, 28), (65, 30), (52, 26), (65, 50),\n                                (84, 52), (53, 54), (43, 52), (45, 71),\n                                (81, 69), (77, 39), (57, 43), (58, 64),\n                                (46, 42), (49, 65)],\n                     'skeleton': [(0, 1), (3, 1), (3, 4), (10, 9), (11, 8),\n                                  (1, 10), (10, 11), (11, 3), (1, 2), (7, 6),\n                                  (5, 13), (5, 3), (13, 11), (12, 10), (12, 2),\n                                  (6, 10), (7, 11)], 'count': 0,\n                     'color_idx': 23, 'prev_pt': (71, 45), 'prev_pt_idx': 7,\n                     'prev_clicked': (56, 63),\n                     'original_support_image': None,\n                 }],\n                ['examples/person1.jpeg',\n                 'examples/person1.jpeg',\n                 'examples/person2.jpeg',\n                 {\n                     'kp_src': [(121, 95), (122, 160), (154, 130), (184, 106),\n                                (181, 153)],\n                     'skeleton': [(0, 1), (1, 2), (0, 2), (2, 3), (2, 4),\n                                  (4, 3)], 'count': 0, 'color_idx': 6,\n                     'prev_pt': (153, 181), 'prev_pt_idx': 4,\n                     'prev_clicked': (181, 108),\n                     'original_support_image': None,\n                 }]\n            ],\n            inputs=[support_img, posed_support, query_img, state],\n            outputs=[support_img, posed_support, query_img, state],\n            fn=update_examples,\n            run_on_click=True,\n        )\n\n    
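# Wire up the UI events: clicks on the support image add keypoints\n    # (get_select_coords), clicks on the posed support image connect keypoints\n    # into limbs (get_limbs), uploading a new support image resets the\n    # annotation state (set_query), and Evaluate runs the model (process).\n    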
support_img.select(get_select_coords,\n                       [support_img, posed_support, state],\n                       [support_img, posed_support, state])\n    support_img.upload(set_query,\n                       inputs=[support_img, state],\n                       outputs=[support_img, posed_support, state])\n    posed_support.select(get_limbs,\n                         [posed_support, state],\n                         [posed_support, state])\n    eval_btn.click(fn=process,\n                   inputs=[query_img, state],\n                   outputs=[output_img, state])\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description='Pose Anything Demo')\n    parser.add_argument('--checkpoint',\n                        help='checkpoint path',\n                        default='https://github.com/orhir/PoseAnything'\n                                '/releases/download/1.0.0/demo_b.pth')\n    args = parser.parse_args()\n    checkpoint_path = args.checkpoint\n    demo.launch()\n"
  },
  {
    "path": "configs/1shot-swin/base_split1_config.py",
    "content": "log_level = 'INFO'\nload_from = None\nresume_from = None\ndist_params = dict(backend='nccl')\nworkflow = [('train', 1)]\ncheckpoint_config = dict(interval=20)\nevaluation = dict(\n    interval=25,\n    metric=['PCK', 'NME', 'AUC', 'EPE'],\n    key_indicator='PCK',\n    gpu_collect=True,\n    res_folder='')\noptimizer = dict(\n    type='Adam',\n    lr=1e-5,\n)\n\noptimizer_config = dict(grad_clip=None)\n# learning policy\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[160, 180])\ntotal_epochs = 200\nlog_config = dict(\n    interval=50,\n    hooks=[\n        dict(type='TextLoggerHook'),\n        dict(type='TensorboardLoggerHook')\n    ])\n\nchannel_cfg = dict(\n    num_output_channels=1,\n    dataset_joints=1,\n    dataset_channel=[\n        [\n            0,\n        ],\n    ],\n    inference_channel=[\n        0,\n    ],\n    max_kpt_num=100)\n\n# model settings\nmodel = dict(\n    type='PoseAnythingModel',\n    pretrained='pretrained/swinv2_small_1k_500k.pth',\n    encoder_config=dict(\n        type='SwinTransformerV2',\n        embed_dim=96,\n        depths=[2, 2, 18, 2],\n        num_heads=[3, 6, 12, 24],\n        window_size=16,\n        drop_path_rate=0.3,\n        img_size=256,\n        upsample=\"bilinear\"\n    ),\n    keypoint_head=dict(\n        type='PoseHead',\n        in_channels=768,\n        transformer=dict(\n            type='EncoderDecoder',\n            d_model=256,\n            nhead=8,\n            num_encoder_layers=3,\n            num_decoder_layers=3,\n            dim_feedforward=768,\n            dropout=0.1,\n            similarity_proj_dim=256,\n            dynamic_proj_dim=128,\n            activation=\"relu\",\n            normalize_before=False,\n            return_intermediate_dec=True),\n        share_kpt_branch=False,\n        num_decoder_layer=3,\n        with_heatmap_loss=True,\n        \n        heatmap_loss_weight=2.0,\n        support_order_dropout=-1,\n        positional_encoding=dict(\n            type='SinePositionalEncoding', num_feats=128, normalize=True)),\n    # training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(\n        flip_test=False,\n        post_process='default',\n        shift_heatmap=True,\n        modulate_kernel=11))\n\ndata_cfg = dict(\n    image_size=[256, 256],\n    heatmap_size=[64, 64],\n    num_output_channels=channel_cfg['num_output_channels'],\n    num_joints=channel_cfg['dataset_joints'],\n    dataset_channel=channel_cfg['dataset_channel'],\n    inference_channel=channel_cfg['inference_channel'])\n\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='TopDownGetRandomScaleRotation', rot_factor=15,\n        scale_factor=0.15),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',\n            'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',\n        ]),\n]\n\nvalid_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    
dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',\n            'flip_pairs', 'category_id',\n            'skeleton',\n        ]),\n]\n\ntest_pipeline = valid_pipeline\n\ndata_root = 'data/mp100'\ndata = dict(\n    samples_per_gpu=8,\n    workers_per_gpu=8,\n    train=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split1_train.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        pipeline=train_pipeline),\n    val=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split1_val.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        num_queries=15,\n        num_episodes=100,\n        pipeline=valid_pipeline),\n    test=dict(\n        type='TestPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split1_test.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        num_queries=15,\n        num_episodes=200,\n        pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],\n        pipeline=test_pipeline),\n)\nvis_backends = [\n    dict(type='LocalVisBackend'),\n    dict(type='TensorboardVisBackend'),\n]\nvisualizer = dict(\n    type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')\n\nshuffle_cfg = dict(interval=1)\n"
  },
  {
    "path": "configs/1shot-swin/base_split2_config.py",
    "content": "log_level = 'INFO'\nload_from = None\nresume_from = None\ndist_params = dict(backend='nccl')\nworkflow = [('train', 1)]\ncheckpoint_config = dict(interval=20)\nevaluation = dict(\n    interval=25,\n    metric=['PCK', 'NME', 'AUC', 'EPE'],\n    key_indicator='PCK',\n    gpu_collect=True,\n    res_folder='')\noptimizer = dict(\n    type='Adam',\n    lr=1e-5,\n)\n\noptimizer_config = dict(grad_clip=None)\n# learning policy\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[160, 180])\ntotal_epochs = 200\nlog_config = dict(\n    interval=50,\n    hooks=[\n        dict(type='TextLoggerHook'),\n        dict(type='TensorboardLoggerHook')\n    ])\n\nchannel_cfg = dict(\n    num_output_channels=1,\n    dataset_joints=1,\n    dataset_channel=[\n        [\n            0,\n        ],\n    ],\n    inference_channel=[\n        0,\n    ],\n    max_kpt_num=100)\n\n# model settings\nmodel = dict(\n    type='PoseAnythingModel',\n    pretrained='pretrained/swinv2_small_1k_500k.pth',\n    encoder_config=dict(\n        type='SwinTransformerV2',\n        embed_dim=96,\n        depths=[2, 2, 18, 2],\n        num_heads=[3, 6, 12, 24],\n        window_size=16,\n        drop_path_rate=0.3,\n        img_size=256,\n        upsample=\"bilinear\"\n    ),\n    keypoint_head=dict(\n        type='PoseHead',\n        in_channels=768,\n        transformer=dict(\n            type='EncoderDecoder',\n            d_model=256,\n            nhead=8,\n            num_encoder_layers=3,\n            num_decoder_layers=3,\n            dim_feedforward=768,\n            dropout=0.1,\n            similarity_proj_dim=256,\n            dynamic_proj_dim=128,\n            activation=\"relu\",\n            normalize_before=False,\n            return_intermediate_dec=True),\n        share_kpt_branch=False,\n        num_decoder_layer=3,\n        with_heatmap_loss=True,\n        \n        heatmap_loss_weight=2.0,\n        support_order_dropout=-1,\n        positional_encoding=dict(\n            type='SinePositionalEncoding', num_feats=128, normalize=True)),\n    # training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(\n        flip_test=False,\n        post_process='default',\n        shift_heatmap=True,\n        modulate_kernel=11))\n\ndata_cfg = dict(\n    image_size=[256, 256],\n    heatmap_size=[64, 64],\n    num_output_channels=channel_cfg['num_output_channels'],\n    num_joints=channel_cfg['dataset_joints'],\n    dataset_channel=channel_cfg['dataset_channel'],\n    inference_channel=channel_cfg['inference_channel'])\n\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='TopDownGetRandomScaleRotation', rot_factor=15,\n        scale_factor=0.15),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',\n            'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',\n        ]),\n]\n\nvalid_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    
dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',\n            'flip_pairs', 'category_id',\n            'skeleton',\n        ]),\n]\n\ntest_pipeline = valid_pipeline\n\ndata_root = 'data/mp100'\ndata = dict(\n    samples_per_gpu=8,\n    workers_per_gpu=8,\n    train=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split2_train.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        pipeline=train_pipeline),\n    val=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split2_val.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        num_queries=15,\n        num_episodes=100,\n        pipeline=valid_pipeline),\n    test=dict(\n        type='TestPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split2_test.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        num_queries=15,\n        num_episodes=200,\n        pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],\n        pipeline=test_pipeline),\n)\nvis_backends = [\n    dict(type='LocalVisBackend'),\n    dict(type='TensorboardVisBackend'),\n]\nvisualizer = dict(\n    type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')\n\nshuffle_cfg = dict(interval=1)\n"
  },
  {
    "path": "configs/1shot-swin/base_split3_config.py",
    "content": "log_level = 'INFO'\nload_from = None\nresume_from = None\ndist_params = dict(backend='nccl')\nworkflow = [('train', 1)]\ncheckpoint_config = dict(interval=20)\nevaluation = dict(\n    interval=25,\n    metric=['PCK', 'NME', 'AUC', 'EPE'],\n    key_indicator='PCK',\n    gpu_collect=True,\n    res_folder='')\noptimizer = dict(\n    type='Adam',\n    lr=1e-5,\n)\n\noptimizer_config = dict(grad_clip=None)\n# learning policy\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[160, 180])\ntotal_epochs = 200\nlog_config = dict(\n    interval=50,\n    hooks=[\n        dict(type='TextLoggerHook'),\n        dict(type='TensorboardLoggerHook')\n    ])\n\nchannel_cfg = dict(\n    num_output_channels=1,\n    dataset_joints=1,\n    dataset_channel=[\n        [\n            0,\n        ],\n    ],\n    inference_channel=[\n        0,\n    ],\n    max_kpt_num=100)\n\n# model settings\nmodel = dict(\n    type='PoseAnythingModel',\n    pretrained='pretrained/swinv2_small_1k_500k.pth',\n    encoder_config=dict(\n        type='SwinTransformerV2',\n        embed_dim=96,\n        depths=[2, 2, 18, 2],\n        num_heads=[3, 6, 12, 24],\n        window_size=16,\n        drop_path_rate=0.3,\n        img_size=256,\n        upsample=\"bilinear\"\n    ),\n    keypoint_head=dict(\n        type='PoseHead',\n        in_channels=768,\n        transformer=dict(\n            type='EncoderDecoder',\n            d_model=256,\n            nhead=8,\n            num_encoder_layers=3,\n            num_decoder_layers=3,\n            dim_feedforward=768,\n            dropout=0.1,\n            similarity_proj_dim=256,\n            dynamic_proj_dim=128,\n            activation=\"relu\",\n            normalize_before=False,\n            return_intermediate_dec=True),\n        share_kpt_branch=False,\n        num_decoder_layer=3,\n        with_heatmap_loss=True,\n        \n        heatmap_loss_weight=2.0,\n        support_order_dropout=-1,\n        positional_encoding=dict(\n            type='SinePositionalEncoding', num_feats=128, normalize=True)),\n    # training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(\n        flip_test=False,\n        post_process='default',\n        shift_heatmap=True,\n        modulate_kernel=11))\n\ndata_cfg = dict(\n    image_size=[256, 256],\n    heatmap_size=[64, 64],\n    num_output_channels=channel_cfg['num_output_channels'],\n    num_joints=channel_cfg['dataset_joints'],\n    dataset_channel=channel_cfg['dataset_channel'],\n    inference_channel=channel_cfg['inference_channel'])\n\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='TopDownGetRandomScaleRotation', rot_factor=15,\n        scale_factor=0.15),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',\n            'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',\n        ]),\n]\n\nvalid_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    
dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',\n            'flip_pairs', 'category_id',\n            'skeleton',\n        ]),\n]\n\ntest_pipeline = valid_pipeline\n\ndata_root = 'data/mp100'\ndata = dict(\n    samples_per_gpu=8,\n    workers_per_gpu=8,\n    train=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split3_train.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        pipeline=train_pipeline),\n    val=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split3_val.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        num_queries=15,\n        num_episodes=100,\n        pipeline=valid_pipeline),\n    test=dict(\n        type='TestPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split3_test.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        num_queries=15,\n        num_episodes=200,\n        pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],\n        pipeline=test_pipeline),\n)\nvis_backends = [\n    dict(type='LocalVisBackend'),\n    dict(type='TensorboardVisBackend'),\n]\nvisualizer = dict(\n    type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')\n\nshuffle_cfg = dict(interval=1)\n"
  },
  {
    "path": "configs/1shot-swin/base_split4_config.py",
    "content": "log_level = 'INFO'\nload_from = None\nresume_from = None\ndist_params = dict(backend='nccl')\nworkflow = [('train', 1)]\ncheckpoint_config = dict(interval=20)\nevaluation = dict(\n    interval=25,\n    metric=['PCK', 'NME', 'AUC', 'EPE'],\n    key_indicator='PCK',\n    gpu_collect=True,\n    res_folder='')\noptimizer = dict(\n    type='Adam',\n    lr=1e-5,\n)\n\noptimizer_config = dict(grad_clip=None)\n# learning policy\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[160, 180])\ntotal_epochs = 200\nlog_config = dict(\n    interval=50,\n    hooks=[\n        dict(type='TextLoggerHook'),\n        dict(type='TensorboardLoggerHook')\n    ])\n\nchannel_cfg = dict(\n    num_output_channels=1,\n    dataset_joints=1,\n    dataset_channel=[\n        [\n            0,\n        ],\n    ],\n    inference_channel=[\n        0,\n    ],\n    max_kpt_num=100)\n\n# model settings\nmodel = dict(\n    type='PoseAnythingModel',\n    pretrained='pretrained/swinv2_small_1k_500k.pth',\n    encoder_config=dict(\n        type='SwinTransformerV2',\n        embed_dim=96,\n        depths=[2, 2, 18, 2],\n        num_heads=[3, 6, 12, 24],\n        window_size=16,\n        drop_path_rate=0.3,\n        img_size=256,\n        upsample=\"bilinear\"\n    ),\n    keypoint_head=dict(\n        type='PoseHead',\n        in_channels=768,\n        transformer=dict(\n            type='EncoderDecoder',\n            d_model=256,\n            nhead=8,\n            num_encoder_layers=3,\n            num_decoder_layers=3,\n            dim_feedforward=768,\n            dropout=0.1,\n            similarity_proj_dim=256,\n            dynamic_proj_dim=128,\n            activation=\"relu\",\n            normalize_before=False,\n            return_intermediate_dec=True),\n        share_kpt_branch=False,\n        num_decoder_layer=3,\n        with_heatmap_loss=True,\n        \n        heatmap_loss_weight=2.0,\n        support_order_dropout=-1,\n        positional_encoding=dict(\n            type='SinePositionalEncoding', num_feats=128, normalize=True)),\n    # training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(\n        flip_test=False,\n        post_process='default',\n        shift_heatmap=True,\n        modulate_kernel=11))\n\ndata_cfg = dict(\n    image_size=[256, 256],\n    heatmap_size=[64, 64],\n    num_output_channels=channel_cfg['num_output_channels'],\n    num_joints=channel_cfg['dataset_joints'],\n    dataset_channel=channel_cfg['dataset_channel'],\n    inference_channel=channel_cfg['inference_channel'])\n\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='TopDownGetRandomScaleRotation', rot_factor=15,\n        scale_factor=0.15),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',\n            'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',\n        ]),\n]\n\nvalid_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    
dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',\n            'flip_pairs', 'category_id',\n            'skeleton',\n        ]),\n]\n\ntest_pipeline = valid_pipeline\n\ndata_root = 'data/mp100'\ndata = dict(\n    samples_per_gpu=8,\n    workers_per_gpu=8,\n    train=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split4_train.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        pipeline=train_pipeline),\n    val=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split4_val.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        num_queries=15,\n        num_episodes=100,\n        pipeline=valid_pipeline),\n    test=dict(\n        type='TestPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split4_test.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        num_queries=15,\n        num_episodes=200,\n        pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],\n        pipeline=test_pipeline),\n)\nvis_backends = [\n    dict(type='LocalVisBackend'),\n    dict(type='TensorboardVisBackend'),\n]\nvisualizer = dict(\n    type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')\n\nshuffle_cfg = dict(interval=1)\n"
  },
  {
    "path": "configs/1shot-swin/base_split5_config.py",
    "content": "log_level = 'INFO'\nload_from = None\nresume_from = None\ndist_params = dict(backend='nccl')\nworkflow = [('train', 1)]\ncheckpoint_config = dict(interval=20)\nevaluation = dict(\n    interval=25,\n    metric=['PCK', 'NME', 'AUC', 'EPE'],\n    key_indicator='PCK',\n    gpu_collect=True,\n    res_folder='')\noptimizer = dict(\n    type='Adam',\n    lr=1e-5,\n)\n\noptimizer_config = dict(grad_clip=None)\n# learning policy\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[160, 180])\ntotal_epochs = 200\nlog_config = dict(\n    interval=50,\n    hooks=[\n        dict(type='TextLoggerHook'),\n        dict(type='TensorboardLoggerHook')\n    ])\n\nchannel_cfg = dict(\n    num_output_channels=1,\n    dataset_joints=1,\n    dataset_channel=[\n        [\n            0,\n        ],\n    ],\n    inference_channel=[\n        0,\n    ],\n    max_kpt_num=100)\n\n# model settings\nmodel = dict(\n    type='PoseAnythingModel',\n    pretrained='pretrained/swinv2_small_1k_500k.pth',\n    encoder_config=dict(\n        type='SwinTransformerV2',\n        embed_dim=96,\n        depths=[2, 2, 18, 2],\n        num_heads=[3, 6, 12, 24],\n        window_size=16,\n        drop_path_rate=0.3,\n        img_size=256,\n        upsample=\"bilinear\"\n    ),\n    keypoint_head=dict(\n        type='PoseHead',\n        in_channels=768,\n        transformer=dict(\n            type='EncoderDecoder',\n            d_model=256,\n            nhead=8,\n            num_encoder_layers=3,\n            num_decoder_layers=3,\n            dim_feedforward=768,\n            dropout=0.1,\n            similarity_proj_dim=256,\n            dynamic_proj_dim=128,\n            activation=\"relu\",\n            normalize_before=False,\n            return_intermediate_dec=True),\n        share_kpt_branch=False,\n        num_decoder_layer=3,\n        with_heatmap_loss=True,\n        \n        heatmap_loss_weight=2.0,\n        support_order_dropout=-1,\n        positional_encoding=dict(\n            type='SinePositionalEncoding', num_feats=128, normalize=True)),\n    # training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(\n        flip_test=False,\n        post_process='default',\n        shift_heatmap=True,\n        modulate_kernel=11))\n\ndata_cfg = dict(\n    image_size=[256, 256],\n    heatmap_size=[64, 64],\n    num_output_channels=channel_cfg['num_output_channels'],\n    num_joints=channel_cfg['dataset_joints'],\n    dataset_channel=channel_cfg['dataset_channel'],\n    inference_channel=channel_cfg['inference_channel'])\n\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='TopDownGetRandomScaleRotation', rot_factor=15,\n        scale_factor=0.15),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',\n            'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',\n        ]),\n]\n\nvalid_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    
dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',\n            'flip_pairs', 'category_id',\n            'skeleton',\n        ]),\n]\n\ntest_pipeline = valid_pipeline\n\ndata_root = 'data/mp100'\ndata = dict(\n    samples_per_gpu=8,\n    workers_per_gpu=8,\n    train=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split5_train.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        pipeline=train_pipeline),\n    val=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split5_val.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        num_queries=15,\n        num_episodes=100,\n        pipeline=valid_pipeline),\n    test=dict(\n        type='TestPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split5_test.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        num_queries=15,\n        num_episodes=200,\n        pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],\n        pipeline=test_pipeline),\n)\nvis_backends = [\n    dict(type='LocalVisBackend'),\n    dict(type='TensorboardVisBackend'),\n]\nvisualizer = dict(\n    type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')\n\nshuffle_cfg = dict(interval=1)\n"
  },
  {
    "path": "configs/1shot-swin/graph_split1_config.py",
    "content": "log_level = 'INFO'\nload_from = None\nresume_from = None\ndist_params = dict(backend='nccl')\nworkflow = [('train', 1)]\ncheckpoint_config = dict(interval=20)\nevaluation = dict(\n    interval=25,\n    metric=['PCK', 'NME', 'AUC', 'EPE'],\n    key_indicator='PCK',\n    gpu_collect=True,\n    res_folder='')\noptimizer = dict(\n    type='Adam',\n    lr=1e-5,\n)\n\noptimizer_config = dict(grad_clip=None)\n# learning policy\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[160, 180])\ntotal_epochs = 200\nlog_config = dict(\n    interval=50,\n    hooks=[\n        dict(type='TextLoggerHook'),\n        dict(type='TensorboardLoggerHook')\n    ])\n\nchannel_cfg = dict(\n    num_output_channels=1,\n    dataset_joints=1,\n    dataset_channel=[\n        [\n            0,\n        ],\n    ],\n    inference_channel=[\n        0,\n    ],\n    max_kpt_num=100)\n\n# model settings\nmodel = dict(\n    type='PoseAnythingModel',\n    pretrained='pretrained/swinv2_small_1k_500k.pth',\n    encoder_config=dict(\n        type='SwinTransformerV2',\n        embed_dim=96,\n        depths=[2, 2, 18, 2],\n        num_heads=[3, 6, 12, 24],\n        window_size=16,\n        drop_path_rate=0.3,\n        img_size=256,\n        upsample=\"bilinear\"\n    ),\n    keypoint_head=dict(\n        type='PoseHead',\n        in_channels=768,\n        transformer=dict(\n            type='EncoderDecoder',\n            d_model=256,\n            nhead=8,\n            num_encoder_layers=3,\n            num_decoder_layers=3,\n            graph_decoder='pre',\n            dim_feedforward=768,\n            dropout=0.1,\n            similarity_proj_dim=256,\n            dynamic_proj_dim=128,\n            activation=\"relu\",\n            normalize_before=False,\n            return_intermediate_dec=True),\n        share_kpt_branch=False,\n        num_decoder_layer=3,\n        with_heatmap_loss=True,\n        \n        heatmap_loss_weight=2.0,\n        support_order_dropout=-1,\n        positional_encoding=dict(\n            type='SinePositionalEncoding', num_feats=128, normalize=True)),\n    # training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(\n        flip_test=False,\n        post_process='default',\n        shift_heatmap=True,\n        modulate_kernel=11))\n\ndata_cfg = dict(\n    image_size=[256, 256],\n    heatmap_size=[64, 64],\n    num_output_channels=channel_cfg['num_output_channels'],\n    num_joints=channel_cfg['dataset_joints'],\n    dataset_channel=channel_cfg['dataset_channel'],\n    inference_channel=channel_cfg['inference_channel'])\n\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='TopDownGetRandomScaleRotation', rot_factor=15,\n        scale_factor=0.15),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',\n            'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',\n        ]),\n]\n\nvalid_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        
std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',\n            'flip_pairs', 'category_id',\n            'skeleton',\n        ]),\n]\n\ntest_pipeline = valid_pipeline\n\ndata_root = 'data/mp100'\ndata = dict(\n    samples_per_gpu=8,\n    workers_per_gpu=8,\n    train=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split1_train.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        pipeline=train_pipeline),\n    val=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split1_val.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        num_queries=15,\n        num_episodes=100,\n        pipeline=valid_pipeline),\n    test=dict(\n        type='TestPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split1_test.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        num_queries=15,\n        num_episodes=200,\n        pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],\n        pipeline=test_pipeline),\n)\nvis_backends = [\n    dict(type='LocalVisBackend'),\n    dict(type='TensorboardVisBackend'),\n]\nvisualizer = dict(\n    type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')\n\nshuffle_cfg = dict(interval=1)\n"
  },
  {
    "path": "configs/1shot-swin/graph_split2_config.py",
    "content": "log_level = 'INFO'\nload_from = None\nresume_from = None\ndist_params = dict(backend='nccl')\nworkflow = [('train', 1)]\ncheckpoint_config = dict(interval=20)\nevaluation = dict(\n    interval=25,\n    metric=['PCK', 'NME', 'AUC', 'EPE'],\n    key_indicator='PCK',\n    gpu_collect=True,\n    res_folder='')\noptimizer = dict(\n    type='Adam',\n    lr=1e-5,\n)\n\noptimizer_config = dict(grad_clip=None)\n# learning policy\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[160, 180])\ntotal_epochs = 200\nlog_config = dict(\n    interval=50,\n    hooks=[\n        dict(type='TextLoggerHook'),\n        dict(type='TensorboardLoggerHook')\n    ])\n\nchannel_cfg = dict(\n    num_output_channels=1,\n    dataset_joints=1,\n    dataset_channel=[\n        [\n            0,\n        ],\n    ],\n    inference_channel=[\n        0,\n    ],\n    max_kpt_num=100)\n\n# model settings\nmodel = dict(\n    type='PoseAnythingModel',\n    pretrained='pretrained/swinv2_small_1k_500k.pth',\n    encoder_config=dict(\n        type='SwinTransformerV2',\n        embed_dim=96,\n        depths=[2, 2, 18, 2],\n        num_heads=[3, 6, 12, 24],\n        window_size=16,\n        drop_path_rate=0.3,\n        img_size=256,\n        upsample=\"bilinear\"\n    ),\n    keypoint_head=dict(\n        type='PoseHead',\n        in_channels=768,\n        transformer=dict(\n            type='EncoderDecoder',\n            d_model=256,\n            nhead=8,\n            num_encoder_layers=3,\n            num_decoder_layers=3,\n            graph_decoder='pre',\n            dim_feedforward=768,\n            dropout=0.1,\n            similarity_proj_dim=256,\n            dynamic_proj_dim=128,\n            activation=\"relu\",\n            normalize_before=False,\n            return_intermediate_dec=True),\n        share_kpt_branch=False,\n        num_decoder_layer=3,\n        with_heatmap_loss=True,\n        \n        heatmap_loss_weight=2.0,\n        support_order_dropout=-1,\n        positional_encoding=dict(\n            type='SinePositionalEncoding', num_feats=128, normalize=True)),\n    # training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(\n        flip_test=False,\n        post_process='default',\n        shift_heatmap=True,\n        modulate_kernel=11))\n\ndata_cfg = dict(\n    image_size=[256, 256],\n    heatmap_size=[64, 64],\n    num_output_channels=channel_cfg['num_output_channels'],\n    num_joints=channel_cfg['dataset_joints'],\n    dataset_channel=channel_cfg['dataset_channel'],\n    inference_channel=channel_cfg['inference_channel'])\n\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='TopDownGetRandomScaleRotation', rot_factor=15,\n        scale_factor=0.15),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',\n            'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',\n        ]),\n]\n\nvalid_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        
std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',\n            'flip_pairs', 'category_id',\n            'skeleton',\n        ]),\n]\n\ntest_pipeline = valid_pipeline\n\ndata_root = 'data/mp100'\ndata = dict(\n    samples_per_gpu=8,\n    workers_per_gpu=8,\n    train=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split2_train.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        pipeline=train_pipeline),\n    val=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split2_val.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        num_queries=15,\n        num_episodes=100,\n        pipeline=valid_pipeline),\n    test=dict(\n        type='TestPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split2_test.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        num_queries=15,\n        num_episodes=200,\n        pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],\n        pipeline=test_pipeline),\n)\nvis_backends = [\n    dict(type='LocalVisBackend'),\n    dict(type='TensorboardVisBackend'),\n]\nvisualizer = dict(\n    type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')\n\nshuffle_cfg = dict(interval=1)\n"
  },
  {
    "path": "configs/1shot-swin/graph_split3_config.py",
    "content": "log_level = 'INFO'\nload_from = None\nresume_from = None\ndist_params = dict(backend='nccl')\nworkflow = [('train', 1)]\ncheckpoint_config = dict(interval=20)\nevaluation = dict(\n    interval=25,\n    metric=['PCK', 'NME', 'AUC', 'EPE'],\n    key_indicator='PCK',\n    gpu_collect=True,\n    res_folder='')\noptimizer = dict(\n    type='Adam',\n    lr=1e-5,\n)\n\noptimizer_config = dict(grad_clip=None)\n# learning policy\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[160, 180])\ntotal_epochs = 200\nlog_config = dict(\n    interval=50,\n    hooks=[\n        dict(type='TextLoggerHook'),\n        dict(type='TensorboardLoggerHook')\n    ])\n\nchannel_cfg = dict(\n    num_output_channels=1,\n    dataset_joints=1,\n    dataset_channel=[\n        [\n            0,\n        ],\n    ],\n    inference_channel=[\n        0,\n    ],\n    max_kpt_num=100)\n\n# model settings\nmodel = dict(\n    type='PoseAnythingModel',\n    pretrained='pretrained/swinv2_small_1k_500k.pth',\n    encoder_config=dict(\n        type='SwinTransformerV2',\n        embed_dim=96,\n        depths=[2, 2, 18, 2],\n        num_heads=[3, 6, 12, 24],\n        window_size=16,\n        drop_path_rate=0.3,\n        img_size=256,\n        upsample=\"bilinear\"\n    ),\n    keypoint_head=dict(\n        type='PoseHead',\n        in_channels=768,\n        transformer=dict(\n            type='EncoderDecoder',\n            d_model=256,\n            nhead=8,\n            num_encoder_layers=3,\n            num_decoder_layers=3,\n            graph_decoder='pre',\n            dim_feedforward=768,\n            dropout=0.1,\n            similarity_proj_dim=256,\n            dynamic_proj_dim=128,\n            activation=\"relu\",\n            normalize_before=False,\n            return_intermediate_dec=True),\n        share_kpt_branch=False,\n        num_decoder_layer=3,\n        with_heatmap_loss=True,\n        \n        heatmap_loss_weight=2.0,\n        support_order_dropout=-1,\n        positional_encoding=dict(\n            type='SinePositionalEncoding', num_feats=128, normalize=True)),\n    # training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(\n        flip_test=False,\n        post_process='default',\n        shift_heatmap=True,\n        modulate_kernel=11))\n\ndata_cfg = dict(\n    image_size=[256, 256],\n    heatmap_size=[64, 64],\n    num_output_channels=channel_cfg['num_output_channels'],\n    num_joints=channel_cfg['dataset_joints'],\n    dataset_channel=channel_cfg['dataset_channel'],\n    inference_channel=channel_cfg['inference_channel'])\n\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='TopDownGetRandomScaleRotation', rot_factor=15,\n        scale_factor=0.15),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',\n            'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',\n        ]),\n]\n\nvalid_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        
std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',\n            'flip_pairs', 'category_id',\n            'skeleton',\n        ]),\n]\n\ntest_pipeline = valid_pipeline\n\ndata_root = 'data/mp100'\ndata = dict(\n    samples_per_gpu=8,\n    workers_per_gpu=8,\n    train=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split3_train.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        pipeline=train_pipeline),\n    val=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split3_val.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        num_queries=15,\n        num_episodes=100,\n        pipeline=valid_pipeline),\n    test=dict(\n        type='TestPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split3_test.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        num_queries=15,\n        num_episodes=200,\n        pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],\n        pipeline=test_pipeline),\n)\nvis_backends = [\n    dict(type='LocalVisBackend'),\n    dict(type='TensorboardVisBackend'),\n]\nvisualizer = dict(\n    type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')\n\nshuffle_cfg = dict(interval=1)\n"
  },
  {
    "path": "configs/1shot-swin/graph_split4_config.py",
    "content": "log_level = 'INFO'\nload_from = None\nresume_from = None\ndist_params = dict(backend='nccl')\nworkflow = [('train', 1)]\ncheckpoint_config = dict(interval=20)\nevaluation = dict(\n    interval=25,\n    metric=['PCK', 'NME', 'AUC', 'EPE'],\n    key_indicator='PCK',\n    gpu_collect=True,\n    res_folder='')\noptimizer = dict(\n    type='Adam',\n    lr=1e-5,\n)\n\noptimizer_config = dict(grad_clip=None)\n# learning policy\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[160, 180])\ntotal_epochs = 200\nlog_config = dict(\n    interval=50,\n    hooks=[\n        dict(type='TextLoggerHook'),\n        dict(type='TensorboardLoggerHook')\n    ])\n\nchannel_cfg = dict(\n    num_output_channels=1,\n    dataset_joints=1,\n    dataset_channel=[\n        [\n            0,\n        ],\n    ],\n    inference_channel=[\n        0,\n    ],\n    max_kpt_num=100)\n\n# model settings\nmodel = dict(\n    type='PoseAnythingModel',\n    pretrained='pretrained/swinv2_small_1k_500k.pth',\n    encoder_config=dict(\n        type='SwinTransformerV2',\n        embed_dim=96,\n        depths=[2, 2, 18, 2],\n        num_heads=[3, 6, 12, 24],\n        window_size=16,\n        drop_path_rate=0.3,\n        img_size=256,\n        upsample=\"bilinear\"\n    ),\n    keypoint_head=dict(\n        type='PoseHead',\n        in_channels=768,\n        transformer=dict(\n            type='EncoderDecoder',\n            d_model=256,\n            nhead=8,\n            num_encoder_layers=3,\n            num_decoder_layers=3,\n            graph_decoder='pre',\n            dim_feedforward=768,\n            dropout=0.1,\n            similarity_proj_dim=256,\n            dynamic_proj_dim=128,\n            activation=\"relu\",\n            normalize_before=False,\n            return_intermediate_dec=True),\n        share_kpt_branch=False,\n        num_decoder_layer=3,\n        with_heatmap_loss=True,\n        \n        heatmap_loss_weight=2.0,\n        support_order_dropout=-1,\n        positional_encoding=dict(\n            type='SinePositionalEncoding', num_feats=128, normalize=True)),\n    # training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(\n        flip_test=False,\n        post_process='default',\n        shift_heatmap=True,\n        modulate_kernel=11))\n\ndata_cfg = dict(\n    image_size=[256, 256],\n    heatmap_size=[64, 64],\n    num_output_channels=channel_cfg['num_output_channels'],\n    num_joints=channel_cfg['dataset_joints'],\n    dataset_channel=channel_cfg['dataset_channel'],\n    inference_channel=channel_cfg['inference_channel'])\n\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='TopDownGetRandomScaleRotation', rot_factor=15,\n        scale_factor=0.15),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',\n            'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',\n        ]),\n]\n\nvalid_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        
std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',\n            'flip_pairs', 'category_id',\n            'skeleton',\n        ]),\n]\n\ntest_pipeline = valid_pipeline\n\ndata_root = 'data/mp100'\ndata = dict(\n    samples_per_gpu=8,\n    workers_per_gpu=8,\n    train=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split4_train.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        pipeline=train_pipeline),\n    val=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split4_val.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        num_queries=15,\n        num_episodes=100,\n        pipeline=valid_pipeline),\n    test=dict(\n        type='TestPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split4_test.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        num_queries=15,\n        num_episodes=200,\n        pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],\n        pipeline=test_pipeline),\n)\nvis_backends = [\n    dict(type='LocalVisBackend'),\n    dict(type='TensorboardVisBackend'),\n]\nvisualizer = dict(\n    type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')\n\nshuffle_cfg = dict(interval=1)\n"
  },
  {
    "path": "configs/1shot-swin/graph_split5_config.py",
    "content": "log_level = 'INFO'\nload_from = None\nresume_from = None\ndist_params = dict(backend='nccl')\nworkflow = [('train', 1)]\ncheckpoint_config = dict(interval=20)\nevaluation = dict(\n    interval=25,\n    metric=['PCK', 'NME', 'AUC', 'EPE'],\n    key_indicator='PCK',\n    gpu_collect=True,\n    res_folder='')\noptimizer = dict(\n    type='Adam',\n    lr=1e-5,\n)\n\noptimizer_config = dict(grad_clip=None)\n# learning policy\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[160, 180])\ntotal_epochs = 200\nlog_config = dict(\n    interval=50,\n    hooks=[\n        dict(type='TextLoggerHook'),\n        dict(type='TensorboardLoggerHook')\n    ])\n\nchannel_cfg = dict(\n    num_output_channels=1,\n    dataset_joints=1,\n    dataset_channel=[\n        [\n            0,\n        ],\n    ],\n    inference_channel=[\n        0,\n    ],\n    max_kpt_num=100)\n\n# model settings\nmodel = dict(\n    type='PoseAnythingModel',\n    pretrained='pretrained/swinv2_small_1k_500k.pth',\n    encoder_config=dict(\n        type='SwinTransformerV2',\n        embed_dim=96,\n        depths=[2, 2, 18, 2],\n        num_heads=[3, 6, 12, 24],\n        window_size=16,\n        drop_path_rate=0.3,\n        img_size=256,\n        upsample=\"bilinear\"\n    ),\n    keypoint_head=dict(\n        type='PoseHead',\n        in_channels=768,\n        transformer=dict(\n            type='EncoderDecoder',\n            d_model=256,\n            nhead=8,\n            num_encoder_layers=3,\n            num_decoder_layers=3,\n            graph_decoder='pre',\n            dim_feedforward=768,\n            dropout=0.1,\n            similarity_proj_dim=256,\n            dynamic_proj_dim=128,\n            activation=\"relu\",\n            normalize_before=False,\n            return_intermediate_dec=True),\n        share_kpt_branch=False,\n        num_decoder_layer=3,\n        with_heatmap_loss=True,\n        \n        heatmap_loss_weight=2.0,\n        support_order_dropout=-1,\n        positional_encoding=dict(\n            type='SinePositionalEncoding', num_feats=128, normalize=True)),\n    # training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(\n        flip_test=False,\n        post_process='default',\n        shift_heatmap=True,\n        modulate_kernel=11))\n\ndata_cfg = dict(\n    image_size=[256, 256],\n    heatmap_size=[64, 64],\n    num_output_channels=channel_cfg['num_output_channels'],\n    num_joints=channel_cfg['dataset_joints'],\n    dataset_channel=channel_cfg['dataset_channel'],\n    inference_channel=channel_cfg['inference_channel'])\n\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='TopDownGetRandomScaleRotation', rot_factor=15,\n        scale_factor=0.15),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',\n            'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',\n        ]),\n]\n\nvalid_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        
std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',\n            'flip_pairs', 'category_id',\n            'skeleton',\n        ]),\n]\n\ntest_pipeline = valid_pipeline\n\ndata_root = 'data/mp100'\ndata = dict(\n    samples_per_gpu=8,\n    workers_per_gpu=8,\n    train=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split5_train.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        pipeline=train_pipeline),\n    val=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split5_val.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        num_queries=15,\n        num_episodes=100,\n        pipeline=valid_pipeline),\n    test=dict(\n        type='TestPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split5_test.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        num_queries=15,\n        num_episodes=200,\n        pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],\n        pipeline=test_pipeline),\n)\nvis_backends = [\n    dict(type='LocalVisBackend'),\n    dict(type='TensorboardVisBackend'),\n]\nvisualizer = dict(\n    type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')\n\nshuffle_cfg = dict(interval=1)\n"
  },
  {
    "path": "configs/1shots/base_split1_config.py",
    "content": "log_level = 'INFO'\nload_from = None\nresume_from = None\ndist_params = dict(backend='nccl')\nworkflow = [('train', 1)]\ncheckpoint_config = dict(interval=20)\nevaluation = dict(\n    interval=25,\n    metric=['PCK', 'NME', 'AUC', 'EPE'],\n    key_indicator='PCK',\n    gpu_collect=True,\n    res_folder='')\noptimizer = dict(\n    type='Adam',\n    lr=1e-5,\n)\n\noptimizer_config = dict(grad_clip=None)\n# learning policy\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[160, 180])\ntotal_epochs = 200\nlog_config = dict(\n    interval=50,\n    hooks=[\n        dict(type='TextLoggerHook'),\n        dict(type='TensorboardLoggerHook')\n    ])\n\nchannel_cfg = dict(\n    num_output_channels=1,\n    dataset_joints=1,\n    dataset_channel=[\n        [\n            0,\n        ],\n    ],\n    inference_channel=[\n        0,\n    ],\n    max_kpt_num=100)\n\n# model settings\nmodel = dict(\n    type='PoseAnythingModel',\n    pretrained='pretrained/swinv2_tiny_patch4_window16_256.pth',\n    encoder_config=dict(\n        type='SwinTransformerV2',\n        embed_dim=96,\n        depths=[2, 2, 6, 2],\n        num_heads=[3, 6, 12, 24],\n        window_size=16,\n        drop_path_rate=0.2,\n        img_size=256,\n        upsample=\"bilinear\"\n    ),\n    keypoint_head=dict(\n        type='PoseHead',\n        in_channels=768,\n        transformer=dict(\n            type='EncoderDecoder',\n            d_model=256,\n            nhead=8,\n            num_encoder_layers=3,\n            num_decoder_layers=3,\n            dim_feedforward=768,\n            dropout=0.1,\n            similarity_proj_dim=256,\n            dynamic_proj_dim=128,\n            activation=\"relu\",\n            normalize_before=False,\n            return_intermediate_dec=True),\n        share_kpt_branch=False,\n        num_decoder_layer=3,\n        with_heatmap_loss=True,\n        heatmap_loss_weight=2.0,\n        support_order_dropout=-1,\n        positional_encoding=dict(\n            type='SinePositionalEncoding', num_feats=128, normalize=True)),\n\n    # training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(\n        flip_test=False,\n        post_process='default',\n        shift_heatmap=True,\n        modulate_kernel=11))\n\ndata_cfg = dict(\n    image_size=[256, 256],\n    heatmap_size=[64, 64],\n    num_output_channels=channel_cfg['num_output_channels'],\n    num_joints=channel_cfg['dataset_joints'],\n    dataset_channel=channel_cfg['dataset_channel'],\n    inference_channel=channel_cfg['inference_channel'])\n\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='TopDownGetRandomScaleRotation', rot_factor=15,\n        scale_factor=0.15),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',\n            'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',\n        ]),\n]\n\nvalid_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    
dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',\n            'flip_pairs', 'category_id',\n            'skeleton',\n        ]),\n]\n\ntest_pipeline = valid_pipeline\n\ndata_root = 'data/mp100'\ndata = dict(\n    samples_per_gpu=16,\n    workers_per_gpu=8,\n    train=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split1_train.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        pipeline=train_pipeline),\n    val=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split1_val.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        num_queries=15,\n        num_episodes=100,\n        pipeline=valid_pipeline),\n    test=dict(\n        type='TestPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split1_test.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        num_queries=15,\n        num_episodes=200,\n        pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],\n        pipeline=test_pipeline),\n)\nvis_backends = [\n    dict(type='LocalVisBackend'),\n    dict(type='TensorboardVisBackend'),\n]\nvisualizer = dict(\n    type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')\n\nshuffle_cfg = dict(interval=1)\n"
  },
  {
    "path": "configs/1shots/base_split2_config.py",
    "content": "log_level = 'INFO'\nload_from = None\nresume_from = None\ndist_params = dict(backend='nccl')\nworkflow = [('train', 1)]\ncheckpoint_config = dict(interval=20)\nevaluation = dict(\n    interval=25,\n    metric=['PCK', 'NME', 'AUC', 'EPE'],\n    key_indicator='PCK',\n    gpu_collect=True,\n    res_folder='')\noptimizer = dict(\n    type='Adam',\n    lr=1e-5,\n)\n\noptimizer_config = dict(grad_clip=None)\n# learning policy\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[160, 180])\ntotal_epochs = 200\nlog_config = dict(\n    interval=50,\n    hooks=[\n        dict(type='TextLoggerHook'),\n        dict(type='TensorboardLoggerHook')\n    ])\n\nchannel_cfg = dict(\n    num_output_channels=1,\n    dataset_joints=1,\n    dataset_channel=[\n        [\n            0,\n        ],\n    ],\n    inference_channel=[\n        0,\n    ],\n    max_kpt_num=100)\n\n# model settings\nmodel = dict(\n    type='PoseAnythingModel',\n    pretrained='pretrained/swinv2_tiny_patch4_window16_256.pth',\n    encoder_config=dict(\n        type='SwinTransformerV2',\n        embed_dim=96,\n        depths=[2, 2, 6, 2],\n        num_heads=[3, 6, 12, 24],\n        window_size=16,\n        drop_path_rate=0.2,\n        img_size=256,\n        upsample=\"bilinear\"\n    ),\n    keypoint_head=dict(\n        type='PoseHead',\n        in_channels=768,\n        transformer=dict(\n            type='EncoderDecoder',\n            d_model=256,\n            nhead=8,\n            num_encoder_layers=3,\n            num_decoder_layers=3,\n            dim_feedforward=768,\n            dropout=0.1,\n            similarity_proj_dim=256,\n            dynamic_proj_dim=128,\n            activation=\"relu\",\n            normalize_before=False,\n            return_intermediate_dec=True),\n        share_kpt_branch=False,\n        num_decoder_layer=3,\n        with_heatmap_loss=True,\n        heatmap_loss_weight=2.0,\n        support_order_dropout=-1,\n        positional_encoding=dict(\n            type='SinePositionalEncoding', num_feats=128, normalize=True)),\n\n    # training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(\n        flip_test=False,\n        post_process='default',\n        shift_heatmap=True,\n        modulate_kernel=11))\n\ndata_cfg = dict(\n    image_size=[256, 256],\n    heatmap_size=[64, 64],\n    num_output_channels=channel_cfg['num_output_channels'],\n    num_joints=channel_cfg['dataset_joints'],\n    dataset_channel=channel_cfg['dataset_channel'],\n    inference_channel=channel_cfg['inference_channel'])\n\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='TopDownGetRandomScaleRotation', rot_factor=15,\n        scale_factor=0.15),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',\n            'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',\n        ]),\n]\n\nvalid_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    
dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',\n            'flip_pairs', 'category_id',\n            'skeleton',\n        ]),\n]\n\ntest_pipeline = valid_pipeline\n\ndata_root = 'data/mp100'\ndata = dict(\n    samples_per_gpu=16,\n    workers_per_gpu=8,\n    train=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split2_train.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        pipeline=train_pipeline),\n    val=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split2_val.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        num_queries=15,\n        num_episodes=100,\n        pipeline=valid_pipeline),\n    test=dict(\n        type='TestPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split2_test.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        num_queries=15,\n        num_episodes=200,\n        pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],\n        pipeline=test_pipeline),\n)\nvis_backends = [\n    dict(type='LocalVisBackend'),\n    dict(type='TensorboardVisBackend'),\n]\nvisualizer = dict(\n    type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')\n\nshuffle_cfg = dict(interval=1)\n"
  },
  {
    "path": "configs/1shots/base_split3_config.py",
    "content": "log_level = 'INFO'\nload_from = None\nresume_from = None\ndist_params = dict(backend='nccl')\nworkflow = [('train', 1)]\ncheckpoint_config = dict(interval=20)\nevaluation = dict(\n    interval=25,\n    metric=['PCK', 'NME', 'AUC', 'EPE'],\n    key_indicator='PCK',\n    gpu_collect=True,\n    res_folder='')\noptimizer = dict(\n    type='Adam',\n    lr=1e-5,\n)\n\noptimizer_config = dict(grad_clip=None)\n# learning policy\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[160, 180])\ntotal_epochs = 200\nlog_config = dict(\n    interval=50,\n    hooks=[\n        dict(type='TextLoggerHook'),\n        dict(type='TensorboardLoggerHook')\n    ])\n\nchannel_cfg = dict(\n    num_output_channels=1,\n    dataset_joints=1,\n    dataset_channel=[\n        [\n            0,\n        ],\n    ],\n    inference_channel=[\n        0,\n    ],\n    max_kpt_num=100)\n\n# model settings\nmodel = dict(\n    type='PoseAnythingModel',\n    pretrained='pretrained/swinv2_tiny_patch4_window16_256.pth',\n    encoder_config=dict(\n        type='SwinTransformerV2',\n        embed_dim=96,\n        depths=[2, 2, 6, 2],\n        num_heads=[3, 6, 12, 24],\n        window_size=16,\n        drop_path_rate=0.2,\n        img_size=256,\n        upsample=\"bilinear\"\n    ),\n    keypoint_head=dict(\n        type='PoseHead',\n        in_channels=768,\n        transformer=dict(\n            type='EncoderDecoder',\n            d_model=256,\n            nhead=8,\n            num_encoder_layers=3,\n            num_decoder_layers=3,\n            dim_feedforward=768,\n            dropout=0.1,\n            similarity_proj_dim=256,\n            dynamic_proj_dim=128,\n            activation=\"relu\",\n            normalize_before=False,\n            return_intermediate_dec=True),\n        share_kpt_branch=False,\n        num_decoder_layer=3,\n        with_heatmap_loss=True,\n        heatmap_loss_weight=2.0,\n        support_order_dropout=-1,\n        positional_encoding=dict(\n            type='SinePositionalEncoding', num_feats=128, normalize=True)),\n\n    # training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(\n        flip_test=False,\n        post_process='default',\n        shift_heatmap=True,\n        modulate_kernel=11))\n\ndata_cfg = dict(\n    image_size=[256, 256],\n    heatmap_size=[64, 64],\n    num_output_channels=channel_cfg['num_output_channels'],\n    num_joints=channel_cfg['dataset_joints'],\n    dataset_channel=channel_cfg['dataset_channel'],\n    inference_channel=channel_cfg['inference_channel'])\n\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='TopDownGetRandomScaleRotation', rot_factor=15,\n        scale_factor=0.15),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',\n            'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',\n        ]),\n]\n\nvalid_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    
dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',\n            'flip_pairs', 'category_id',\n            'skeleton',\n        ]),\n]\n\ntest_pipeline = valid_pipeline\n\ndata_root = 'data/mp100'\ndata = dict(\n    samples_per_gpu=16,\n    workers_per_gpu=8,\n    train=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split3_train.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        pipeline=train_pipeline),\n    val=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split3_val.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        num_queries=15,\n        num_episodes=100,\n        pipeline=valid_pipeline),\n    test=dict(\n        type='TestPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split3_test.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        num_queries=15,\n        num_episodes=200,\n        pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],\n        pipeline=test_pipeline),\n)\nvis_backends = [\n    dict(type='LocalVisBackend'),\n    dict(type='TensorboardVisBackend'),\n]\nvisualizer = dict(\n    type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')\n\nshuffle_cfg = dict(interval=1)\n"
  },
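  {
    "path": "docs/examples/pck_metric_sketch.py",
    "content": "# Hypothetical illustration -- NOT part of the original repo; this file\n# path and function are made up for documentation only.\n# The test configs here evaluate PCK at the thresholds given in\n# pck_threshold_list ([0.05, 0.10, 0.15, 0.2, 0.25]). A common\n# bbox-normalized PCK looks like the sketch below; the repo's own\n# evaluator may normalize differently.\nimport numpy as np\n\n\ndef pck(pred, gt, bbox_size, thr=0.2):\n    # pred, gt: (K, 2) keypoint arrays; bbox_size: scalar normalizer\n    # (e.g. the longer side of the ground-truth bounding box).\n    dist = np.linalg.norm(pred - gt, axis=-1) / bbox_size\n    return float((dist < thr).mean())\n\n\nif __name__ == '__main__':\n    gt = np.array([[10.0, 10.0], [50.0, 40.0]])\n    pred = gt + np.array([[1.0, -1.0], [3.0, 2.0]])\n    print(pck(pred, gt, bbox_size=64.0, thr=0.1))  # 1.0: both errors < 6.4 px\n"
  },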
  {
    "path": "configs/1shots/base_split4_config.py",
    "content": "log_level = 'INFO'\nload_from = None\nresume_from = None\ndist_params = dict(backend='nccl')\nworkflow = [('train', 1)]\ncheckpoint_config = dict(interval=20)\nevaluation = dict(\n    interval=25,\n    metric=['PCK', 'NME', 'AUC', 'EPE'],\n    key_indicator='PCK',\n    gpu_collect=True,\n    res_folder='')\noptimizer = dict(\n    type='Adam',\n    lr=1e-5,\n)\n\noptimizer_config = dict(grad_clip=None)\n# learning policy\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[160, 180])\ntotal_epochs = 200\nlog_config = dict(\n    interval=50,\n    hooks=[\n        dict(type='TextLoggerHook'),\n        dict(type='TensorboardLoggerHook')\n    ])\n\nchannel_cfg = dict(\n    num_output_channels=1,\n    dataset_joints=1,\n    dataset_channel=[\n        [\n            0,\n        ],\n    ],\n    inference_channel=[\n        0,\n    ],\n    max_kpt_num=100)\n\n# model settings\nmodel = dict(\n    type='PoseAnythingModel',\n    pretrained='pretrained/swinv2_tiny_patch4_window16_256.pth',\n    encoder_config=dict(\n        type='SwinTransformerV2',\n        embed_dim=96,\n        depths=[2, 2, 6, 2],\n        num_heads=[3, 6, 12, 24],\n        window_size=16,\n        drop_path_rate=0.2,\n        img_size=256,\n        upsample=\"bilinear\"\n    ),\n    keypoint_head=dict(\n        type='PoseHead',\n        in_channels=768,\n        transformer=dict(\n            type='EncoderDecoder',\n            d_model=256,\n            nhead=8,\n            num_encoder_layers=3,\n            num_decoder_layers=3,\n            dim_feedforward=768,\n            dropout=0.1,\n            similarity_proj_dim=256,\n            dynamic_proj_dim=128,\n            activation=\"relu\",\n            normalize_before=False,\n            return_intermediate_dec=True),\n        share_kpt_branch=False,\n        num_decoder_layer=3,\n        with_heatmap_loss=True,\n        heatmap_loss_weight=2.0,\n        support_order_dropout=-1,\n        positional_encoding=dict(\n            type='SinePositionalEncoding', num_feats=128, normalize=True)),\n\n    # training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(\n        flip_test=False,\n        post_process='default',\n        shift_heatmap=True,\n        modulate_kernel=11))\n\ndata_cfg = dict(\n    image_size=[256, 256],\n    heatmap_size=[64, 64],\n    num_output_channels=channel_cfg['num_output_channels'],\n    num_joints=channel_cfg['dataset_joints'],\n    dataset_channel=channel_cfg['dataset_channel'],\n    inference_channel=channel_cfg['inference_channel'])\n\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='TopDownGetRandomScaleRotation', rot_factor=15,\n        scale_factor=0.15),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',\n            'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',\n        ]),\n]\n\nvalid_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    
dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',\n            'flip_pairs', 'category_id',\n            'skeleton',\n        ]),\n]\n\ntest_pipeline = valid_pipeline\n\ndata_root = 'data/mp100'\ndata = dict(\n    samples_per_gpu=16,\n    workers_per_gpu=8,\n    train=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split4_train.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        pipeline=train_pipeline),\n    val=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split4_val.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        num_queries=15,\n        num_episodes=100,\n        pipeline=valid_pipeline),\n    test=dict(\n        type='TestPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split4_test.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        num_queries=15,\n        num_episodes=200,\n        pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],\n        pipeline=test_pipeline),\n)\nvis_backends = [\n    dict(type='LocalVisBackend'),\n    dict(type='TensorboardVisBackend'),\n]\nvisualizer = dict(\n    type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')\n\nshuffle_cfg = dict(interval=1)\n"
  },
  {
    "path": "configs/1shots/base_split5_config.py",
    "content": "log_level = 'INFO'\nload_from = None\nresume_from = None\ndist_params = dict(backend='nccl')\nworkflow = [('train', 1)]\ncheckpoint_config = dict(interval=20)\nevaluation = dict(\n    interval=25,\n    metric=['PCK', 'NME', 'AUC', 'EPE'],\n    key_indicator='PCK',\n    gpu_collect=True,\n    res_folder='')\noptimizer = dict(\n    type='Adam',\n    lr=1e-5,\n)\n\noptimizer_config = dict(grad_clip=None)\n# learning policy\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[160, 180])\ntotal_epochs = 200\nlog_config = dict(\n    interval=50,\n    hooks=[\n        dict(type='TextLoggerHook'),\n        dict(type='TensorboardLoggerHook')\n    ])\n\nchannel_cfg = dict(\n    num_output_channels=1,\n    dataset_joints=1,\n    dataset_channel=[\n        [\n            0,\n        ],\n    ],\n    inference_channel=[\n        0,\n    ],\n    max_kpt_num=100)\n\n# model settings\nmodel = dict(\n    type='PoseAnythingModel',\n    pretrained='pretrained/swinv2_tiny_patch4_window16_256.pth',\n    encoder_config=dict(\n        type='SwinTransformerV2',\n        embed_dim=96,\n        depths=[2, 2, 6, 2],\n        num_heads=[3, 6, 12, 24],\n        window_size=16,\n        drop_path_rate=0.2,\n        img_size=256,\n        upsample=\"bilinear\"\n    ),\n    keypoint_head=dict(\n        type='PoseHead',\n        in_channels=768,\n        transformer=dict(\n            type='EncoderDecoder',\n            d_model=256,\n            nhead=8,\n            num_encoder_layers=3,\n            num_decoder_layers=3,\n            dim_feedforward=768,\n            dropout=0.1,\n            similarity_proj_dim=256,\n            dynamic_proj_dim=128,\n            activation=\"relu\",\n            normalize_before=False,\n            return_intermediate_dec=True),\n        share_kpt_branch=False,\n        num_decoder_layer=3,\n        with_heatmap_loss=True,\n        heatmap_loss_weight=2.0,\n        support_order_dropout=-1,\n        positional_encoding=dict(\n            type='SinePositionalEncoding', num_feats=128, normalize=True)),\n\n    # training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(\n        flip_test=False,\n        post_process='default',\n        shift_heatmap=True,\n        modulate_kernel=11))\n\ndata_cfg = dict(\n    image_size=[256, 256],\n    heatmap_size=[64, 64],\n    num_output_channels=channel_cfg['num_output_channels'],\n    num_joints=channel_cfg['dataset_joints'],\n    dataset_channel=channel_cfg['dataset_channel'],\n    inference_channel=channel_cfg['inference_channel'])\n\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='TopDownGetRandomScaleRotation', rot_factor=15,\n        scale_factor=0.15),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',\n            'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',\n        ]),\n]\n\nvalid_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    
dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',\n            'flip_pairs', 'category_id',\n            'skeleton',\n        ]),\n]\n\ntest_pipeline = valid_pipeline\n\ndata_root = 'data/mp100'\ndata = dict(\n    samples_per_gpu=16,\n    workers_per_gpu=8,\n    train=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split5_train.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        pipeline=train_pipeline),\n    val=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split5_val.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        num_queries=15,\n        num_episodes=100,\n        pipeline=valid_pipeline),\n    test=dict(\n        type='TestPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split5_test.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        num_queries=15,\n        num_episodes=200,\n        pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],\n        pipeline=test_pipeline),\n)\nvis_backends = [\n    dict(type='LocalVisBackend'),\n    dict(type='TensorboardVisBackend'),\n]\nvisualizer = dict(\n    type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')\n\nshuffle_cfg = dict(interval=1)\n"
  },
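  {
    "path": "docs/examples/load_config_sketch.py",
    "content": "# Hypothetical helper -- NOT part of the original repo; the file path is\n# made up for documentation only.\n# Minimal sketch of how the mmcv-style configs above are typically loaded\n# and overridden, assuming mmcv<2.0 (the Config API used by mmpose 0.x).\nfrom mmcv import Config\n\ncfg = Config.fromfile('configs/1shots/base_split4_config.py')\n\n# All five splits share the same model and schedule; only the\n# annotation files differ.\nprint(cfg.total_epochs)         # 200\nprint(cfg.data.train.ann_file)  # data/mp100/annotations/mp100_split4_train.json\n\n# Command-line style overrides go through mmcv's merge_from_dict,\n# which accepts dotted keys for nested fields.\ncfg.merge_from_dict({'data.samples_per_gpu': 8, 'optimizer.lr': 5e-6})\n"
  },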
  {
    "path": "configs/1shots/graph_split1_config.py",
    "content": "log_level = 'INFO'\nload_from = None\nresume_from = None\ndist_params = dict(backend='nccl')\nworkflow = [('train', 1)]\ncheckpoint_config = dict(interval=20)\nevaluation = dict(\n    interval=25,\n    metric=['PCK', 'NME', 'AUC', 'EPE'],\n    key_indicator='PCK',\n    gpu_collect=True,\n    res_folder='')\noptimizer = dict(\n    type='Adam',\n    lr=1e-5,\n)\n\noptimizer_config = dict(grad_clip=None)\n# learning policy\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[160, 180])\ntotal_epochs = 200\nlog_config = dict(\n    interval=50,\n    hooks=[\n        dict(type='TextLoggerHook'),\n        dict(type='TensorboardLoggerHook')\n    ])\n\nchannel_cfg = dict(\n    num_output_channels=1,\n    dataset_joints=1,\n    dataset_channel=[\n        [\n            0,\n        ],\n    ],\n    inference_channel=[\n        0,\n    ],\n    max_kpt_num=100)\n\n# model settings\nmodel = dict(\n    type='PoseAnythingModel',\n    pretrained='pretrained/swinv2_tiny_patch4_window16_256.pth',\n    encoder_config=dict(\n        type='SwinTransformerV2',\n        embed_dim=96,\n        depths=[2, 2, 6, 2],\n        num_heads=[3, 6, 12, 24],\n        window_size=16,\n        drop_path_rate=0.2,\n        img_size=256,\n        upsample=\"bilinear\"\n    ),\n    keypoint_head=dict(\n        type='PoseHead',\n        in_channels=768,\n        transformer=dict(\n            type='EncoderDecoder',\n            d_model=256,\n            nhead=8,\n            num_encoder_layers=3,\n            num_decoder_layers=3,\n            graph_decoder='pre',\n            dim_feedforward=768,\n            dropout=0.1,\n            similarity_proj_dim=256,\n            dynamic_proj_dim=128,\n            activation=\"relu\",\n            normalize_before=False,\n            return_intermediate_dec=True),\n        share_kpt_branch=False,\n        num_decoder_layer=3,\n        with_heatmap_loss=True,\n        \n        heatmap_loss_weight=2.0,\n        support_order_dropout=-1,\n        positional_encoding=dict(\n            type='SinePositionalEncoding', num_feats=128, normalize=True)),\n    # training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(\n        flip_test=False,\n        post_process='default',\n        shift_heatmap=True,\n        modulate_kernel=11))\n\ndata_cfg = dict(\n    image_size=[256, 256],\n    heatmap_size=[64, 64],\n    num_output_channels=channel_cfg['num_output_channels'],\n    num_joints=channel_cfg['dataset_joints'],\n    dataset_channel=channel_cfg['dataset_channel'],\n    inference_channel=channel_cfg['inference_channel'])\n\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='TopDownGetRandomScaleRotation', rot_factor=15,\n        scale_factor=0.15),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',\n            'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',\n        ]),\n]\n\nvalid_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 
0.406],\n        std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',\n            'flip_pairs', 'category_id',\n            'skeleton',\n        ]),\n]\n\ntest_pipeline = valid_pipeline\n\ndata_root = 'data/mp100'\ndata = dict(\n    samples_per_gpu=16,\n    workers_per_gpu=8,\n    train=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split1_train.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        pipeline=train_pipeline),\n    val=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split1_val.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        num_queries=15,\n        num_episodes=100,\n        pipeline=valid_pipeline),\n    test=dict(\n        type='TestPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split1_test.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        num_queries=15,\n        num_episodes=200,\n        pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],\n        pipeline=test_pipeline),\n)\nvis_backends = [\n    dict(type='LocalVisBackend'),\n    dict(type='TensorboardVisBackend'),\n]\nvisualizer = dict(\n    type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')\n\nshuffle_cfg = dict(interval=1)\n"
  },
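  {
    "path": "docs/examples/graph_variant_sketch.py",
    "content": "# Hypothetical note -- NOT part of the original repo; the file path is\n# made up for documentation only.\n# The graph_* configs in this directory appear to differ from the base_*\n# configs by a single transformer key, graph_decoder='pre', which enables\n# the graph decoder variant. Assuming mmcv<2.0, the same variant can be\n# derived from a base config at runtime:\nfrom mmcv import Config\n\ncfg = Config.fromfile('configs/1shots/base_split4_config.py')\ncfg.merge_from_dict(\n    {'model.keypoint_head.transformer.graph_decoder': 'pre'})\nprint(cfg.model.keypoint_head.transformer.graph_decoder)  # 'pre'\n"
  },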
  {
    "path": "configs/1shots/graph_split2_config.py",
    "content": "log_level = 'INFO'\nload_from = None\nresume_from = None\ndist_params = dict(backend='nccl')\nworkflow = [('train', 1)]\ncheckpoint_config = dict(interval=20)\nevaluation = dict(\n    interval=25,\n    metric=['PCK', 'NME', 'AUC', 'EPE'],\n    key_indicator='PCK',\n    gpu_collect=True,\n    res_folder='')\noptimizer = dict(\n    type='Adam',\n    lr=1e-5,\n)\n\noptimizer_config = dict(grad_clip=None)\n# learning policy\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[160, 180])\ntotal_epochs = 200\nlog_config = dict(\n    interval=50,\n    hooks=[\n        dict(type='TextLoggerHook'),\n        dict(type='TensorboardLoggerHook')\n    ])\n\nchannel_cfg = dict(\n    num_output_channels=1,\n    dataset_joints=1,\n    dataset_channel=[\n        [\n            0,\n        ],\n    ],\n    inference_channel=[\n        0,\n    ],\n    max_kpt_num=100)\n\n# model settings\nmodel = dict(\n    type='PoseAnythingModel',\n    pretrained='pretrained/swinv2_tiny_patch4_window16_256.pth',\n    encoder_config=dict(\n        type='SwinTransformerV2',\n        embed_dim=96,\n        depths=[2, 2, 6, 2],\n        num_heads=[3, 6, 12, 24],\n        window_size=16,\n        drop_path_rate=0.2,\n        img_size=256,\n        upsample=\"bilinear\"\n    ),\n    keypoint_head=dict(\n        type='PoseHead',\n        in_channels=768,\n        transformer=dict(\n            type='EncoderDecoder',\n            d_model=256,\n            nhead=8,\n            num_encoder_layers=3,\n            num_decoder_layers=3,\n            graph_decoder='pre',\n            dim_feedforward=768,\n            dropout=0.1,\n            similarity_proj_dim=256,\n            dynamic_proj_dim=128,\n            activation=\"relu\",\n            normalize_before=False,\n            return_intermediate_dec=True),\n        share_kpt_branch=False,\n        num_decoder_layer=3,\n        with_heatmap_loss=True,\n        \n        heatmap_loss_weight=2.0,\n        support_order_dropout=-1,\n        positional_encoding=dict(\n            type='SinePositionalEncoding', num_feats=128, normalize=True)),\n    # training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(\n        flip_test=False,\n        post_process='default',\n        shift_heatmap=True,\n        modulate_kernel=11))\n\ndata_cfg = dict(\n    image_size=[256, 256],\n    heatmap_size=[64, 64],\n    num_output_channels=channel_cfg['num_output_channels'],\n    num_joints=channel_cfg['dataset_joints'],\n    dataset_channel=channel_cfg['dataset_channel'],\n    inference_channel=channel_cfg['inference_channel'])\n\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='TopDownGetRandomScaleRotation', rot_factor=15,\n        scale_factor=0.15),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',\n            'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',\n        ]),\n]\n\nvalid_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 
0.406],\n        std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',\n            'flip_pairs', 'category_id',\n            'skeleton',\n        ]),\n]\n\ntest_pipeline = valid_pipeline\n\ndata_root = 'data/mp100'\ndata = dict(\n    samples_per_gpu=16,\n    workers_per_gpu=8,\n    train=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split2_train.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        pipeline=train_pipeline),\n    val=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split2_val.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        num_queries=15,\n        num_episodes=100,\n        pipeline=valid_pipeline),\n    test=dict(\n        type='TestPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split2_test.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        num_queries=15,\n        num_episodes=200,\n        pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],\n        pipeline=test_pipeline),\n)\nvis_backends = [\n    dict(type='LocalVisBackend'),\n    dict(type='TensorboardVisBackend'),\n]\nvisualizer = dict(\n    type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')\n\nshuffle_cfg = dict(interval=1)\n"
  },
  {
    "path": "configs/1shots/graph_split3_config.py",
    "content": "log_level = 'INFO'\nload_from = None\nresume_from = None\ndist_params = dict(backend='nccl')\nworkflow = [('train', 1)]\ncheckpoint_config = dict(interval=20)\nevaluation = dict(\n    interval=25,\n    metric=['PCK', 'NME', 'AUC', 'EPE'],\n    key_indicator='PCK',\n    gpu_collect=True,\n    res_folder='')\noptimizer = dict(\n    type='Adam',\n    lr=1e-5,\n)\n\noptimizer_config = dict(grad_clip=None)\n# learning policy\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[160, 180])\ntotal_epochs = 200\nlog_config = dict(\n    interval=50,\n    hooks=[\n        dict(type='TextLoggerHook'),\n        dict(type='TensorboardLoggerHook')\n    ])\n\nchannel_cfg = dict(\n    num_output_channels=1,\n    dataset_joints=1,\n    dataset_channel=[\n        [\n            0,\n        ],\n    ],\n    inference_channel=[\n        0,\n    ],\n    max_kpt_num=100)\n\n# model settings\nmodel = dict(\n    type='PoseAnythingModel',\n    pretrained='pretrained/swinv2_tiny_patch4_window16_256.pth',\n    encoder_config=dict(\n        type='SwinTransformerV2',\n        embed_dim=96,\n        depths=[2, 2, 6, 2],\n        num_heads=[3, 6, 12, 24],\n        window_size=16,\n        drop_path_rate=0.2,\n        img_size=256,\n        upsample=\"bilinear\"\n    ),\n    keypoint_head=dict(\n        type='PoseHead',\n        in_channels=768,\n        transformer=dict(\n            type='EncoderDecoder',\n            d_model=256,\n            nhead=8,\n            num_encoder_layers=3,\n            num_decoder_layers=3,\n            graph_decoder='pre',\n            dim_feedforward=768,\n            dropout=0.1,\n            similarity_proj_dim=256,\n            dynamic_proj_dim=128,\n            activation=\"relu\",\n            normalize_before=False,\n            return_intermediate_dec=True),\n        share_kpt_branch=False,\n        num_decoder_layer=3,\n        with_heatmap_loss=True,\n        \n        heatmap_loss_weight=2.0,\n        support_order_dropout=-1,\n        positional_encoding=dict(\n            type='SinePositionalEncoding', num_feats=128, normalize=True)),\n    # training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(\n        flip_test=False,\n        post_process='default',\n        shift_heatmap=True,\n        modulate_kernel=11))\n\ndata_cfg = dict(\n    image_size=[256, 256],\n    heatmap_size=[64, 64],\n    num_output_channels=channel_cfg['num_output_channels'],\n    num_joints=channel_cfg['dataset_joints'],\n    dataset_channel=channel_cfg['dataset_channel'],\n    inference_channel=channel_cfg['inference_channel'])\n\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='TopDownGetRandomScaleRotation', rot_factor=15,\n        scale_factor=0.15),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',\n            'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',\n        ]),\n]\n\nvalid_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 
0.406],\n        std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',\n            'flip_pairs', 'category_id',\n            'skeleton',\n        ]),\n]\n\ntest_pipeline = valid_pipeline\n\ndata_root = 'data/mp100'\ndata = dict(\n    samples_per_gpu=16,\n    workers_per_gpu=8,\n    train=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split3_train.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        pipeline=train_pipeline),\n    val=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split3_val.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        num_queries=15,\n        num_episodes=100,\n        pipeline=valid_pipeline),\n    test=dict(\n        type='TestPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split3_test.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        num_queries=15,\n        num_episodes=200,\n        pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],\n        pipeline=test_pipeline),\n)\nvis_backends = [\n    dict(type='LocalVisBackend'),\n    dict(type='TensorboardVisBackend'),\n]\nvisualizer = dict(\n    type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')\n\nshuffle_cfg = dict(interval=1)\n"
  },
  {
    "path": "configs/1shots/graph_split4_config.py",
    "content": "log_level = 'INFO'\nload_from = None\nresume_from = None\ndist_params = dict(backend='nccl')\nworkflow = [('train', 1)]\ncheckpoint_config = dict(interval=20)\nevaluation = dict(\n    interval=25,\n    metric=['PCK', 'NME', 'AUC', 'EPE'],\n    key_indicator='PCK',\n    gpu_collect=True,\n    res_folder='')\noptimizer = dict(\n    type='Adam',\n    lr=1e-5,\n)\n\noptimizer_config = dict(grad_clip=None)\n# learning policy\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[160, 180])\ntotal_epochs = 200\nlog_config = dict(\n    interval=50,\n    hooks=[\n        dict(type='TextLoggerHook'),\n        dict(type='TensorboardLoggerHook')\n    ])\n\nchannel_cfg = dict(\n    num_output_channels=1,\n    dataset_joints=1,\n    dataset_channel=[\n        [\n            0,\n        ],\n    ],\n    inference_channel=[\n        0,\n    ],\n    max_kpt_num=100)\n\n# model settings\nmodel = dict(\n    type='PoseAnythingModel',\n    pretrained='pretrained/swinv2_tiny_patch4_window16_256.pth',\n    encoder_config=dict(\n        type='SwinTransformerV2',\n        embed_dim=96,\n        depths=[2, 2, 6, 2],\n        num_heads=[3, 6, 12, 24],\n        window_size=16,\n        drop_path_rate=0.2,\n        img_size=256,\n        upsample=\"bilinear\"\n    ),\n    keypoint_head=dict(\n        type='PoseHead',\n        in_channels=768,\n        transformer=dict(\n            type='EncoderDecoder',\n            d_model=256,\n            nhead=8,\n            num_encoder_layers=3,\n            num_decoder_layers=3,\n            graph_decoder='pre',\n            dim_feedforward=768,\n            dropout=0.1,\n            similarity_proj_dim=256,\n            dynamic_proj_dim=128,\n            activation=\"relu\",\n            normalize_before=False,\n            return_intermediate_dec=True),\n        share_kpt_branch=False,\n        num_decoder_layer=3,\n        with_heatmap_loss=True,\n        \n        heatmap_loss_weight=2.0,\n        support_order_dropout=-1,\n        positional_encoding=dict(\n            type='SinePositionalEncoding', num_feats=128, normalize=True)),\n    # training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(\n        flip_test=False,\n        post_process='default',\n        shift_heatmap=True,\n        modulate_kernel=11))\n\ndata_cfg = dict(\n    image_size=[256, 256],\n    heatmap_size=[64, 64],\n    num_output_channels=channel_cfg['num_output_channels'],\n    num_joints=channel_cfg['dataset_joints'],\n    dataset_channel=channel_cfg['dataset_channel'],\n    inference_channel=channel_cfg['inference_channel'])\n\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='TopDownGetRandomScaleRotation', rot_factor=15,\n        scale_factor=0.15),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',\n            'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',\n        ]),\n]\n\nvalid_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 
0.406],\n        std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',\n            'flip_pairs', 'category_id',\n            'skeleton',\n        ]),\n]\n\ntest_pipeline = valid_pipeline\n\ndata_root = 'data/mp100'\ndata = dict(\n    samples_per_gpu=16,\n    workers_per_gpu=8,\n    train=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split4_train.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        pipeline=train_pipeline),\n    val=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split4_val.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        num_queries=15,\n        num_episodes=100,\n        pipeline=valid_pipeline),\n    test=dict(\n        type='TestPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split4_test.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        num_queries=15,\n        num_episodes=200,\n        pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],\n        pipeline=test_pipeline),\n)\nvis_backends = [\n    dict(type='LocalVisBackend'),\n    dict(type='TensorboardVisBackend'),\n]\nvisualizer = dict(\n    type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')\n\nshuffle_cfg = dict(interval=1)\n"
  },
  {
    "path": "configs/1shots/graph_split5_config.py",
    "content": "log_level = 'INFO'\nload_from = None\nresume_from = None\ndist_params = dict(backend='nccl')\nworkflow = [('train', 1)]\ncheckpoint_config = dict(interval=20)\nevaluation = dict(\n    interval=25,\n    metric=['PCK', 'NME', 'AUC', 'EPE'],\n    key_indicator='PCK',\n    gpu_collect=True,\n    res_folder='')\noptimizer = dict(\n    type='Adam',\n    lr=1e-5,\n)\n\noptimizer_config = dict(grad_clip=None)\n# learning policy\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[160, 180])\ntotal_epochs = 200\nlog_config = dict(\n    interval=50,\n    hooks=[\n        dict(type='TextLoggerHook'),\n        dict(type='TensorboardLoggerHook')\n    ])\n\nchannel_cfg = dict(\n    num_output_channels=1,\n    dataset_joints=1,\n    dataset_channel=[\n        [\n            0,\n        ],\n    ],\n    inference_channel=[\n        0,\n    ],\n    max_kpt_num=100)\n\n# model settings\nmodel = dict(\n    type='PoseAnythingModel',\n    pretrained='pretrained/swinv2_tiny_patch4_window16_256.pth',\n    encoder_config=dict(\n        type='SwinTransformerV2',\n        embed_dim=96,\n        depths=[2, 2, 6, 2],\n        num_heads=[3, 6, 12, 24],\n        window_size=16,\n        drop_path_rate=0.2,\n        img_size=256,\n        upsample=\"bilinear\"\n    ),\n    keypoint_head=dict(\n        type='PoseHead',\n        in_channels=768,\n        transformer=dict(\n            type='EncoderDecoder',\n            d_model=256,\n            nhead=8,\n            num_encoder_layers=3,\n            num_decoder_layers=3,\n            graph_decoder='pre',\n            dim_feedforward=768,\n            dropout=0.1,\n            similarity_proj_dim=256,\n            dynamic_proj_dim=128,\n            activation=\"relu\",\n            normalize_before=False,\n            return_intermediate_dec=True),\n        share_kpt_branch=False,\n        num_decoder_layer=3,\n        with_heatmap_loss=True,\n        \n        heatmap_loss_weight=2.0,\n        support_order_dropout=-1,\n        positional_encoding=dict(\n            type='SinePositionalEncoding', num_feats=128, normalize=True)),\n    # training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(\n        flip_test=False,\n        post_process='default',\n        shift_heatmap=True,\n        modulate_kernel=11))\n\ndata_cfg = dict(\n    image_size=[256, 256],\n    heatmap_size=[64, 64],\n    num_output_channels=channel_cfg['num_output_channels'],\n    num_joints=channel_cfg['dataset_joints'],\n    dataset_channel=channel_cfg['dataset_channel'],\n    inference_channel=channel_cfg['inference_channel'])\n\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='TopDownGetRandomScaleRotation', rot_factor=15,\n        scale_factor=0.15),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',\n            'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',\n        ]),\n]\n\nvalid_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 
0.406],\n        std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',\n            'flip_pairs', 'category_id',\n            'skeleton',\n        ]),\n]\n\ntest_pipeline = valid_pipeline\n\ndata_root = 'data/mp100'\ndata = dict(\n    samples_per_gpu=16,\n    workers_per_gpu=8,\n    train=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split5_train.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        pipeline=train_pipeline),\n    val=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split5_val.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        num_queries=15,\n        num_episodes=100,\n        pipeline=valid_pipeline),\n    test=dict(\n        type='TestPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split5_test.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        num_queries=15,\n        num_episodes=200,\n        pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],\n        pipeline=test_pipeline),\n)\nvis_backends = [\n    dict(type='LocalVisBackend'),\n    dict(type='TensorboardVisBackend'),\n]\nvisualizer = dict(\n    type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')\n\nshuffle_cfg = dict(interval=1)\n"
  },
  {
    "path": "configs/5shot-swin/base_split1_config.py",
    "content": "log_level = 'INFO'\nload_from = None\nresume_from = None\ndist_params = dict(backend='nccl')\nworkflow = [('train', 1)]\ncheckpoint_config = dict(interval=20)\nevaluation = dict(\n    interval=25,\n    metric=['PCK', 'NME', 'AUC', 'EPE'],\n    key_indicator='PCK',\n    gpu_collect=True,\n    res_folder='')\noptimizer = dict(\n    type='Adam',\n    lr=1e-5,\n)\n\noptimizer_config = dict(grad_clip=None)\n# learning policy\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[160, 180])\ntotal_epochs = 200\nlog_config = dict(\n    interval=50,\n    hooks=[\n        dict(type='TextLoggerHook'),\n        dict(type='TensorboardLoggerHook')\n    ])\n\nchannel_cfg = dict(\n    num_output_channels=1,\n    dataset_joints=1,\n    dataset_channel=[\n        [\n            0,\n        ],\n    ],\n    inference_channel=[\n        0,\n    ],\n    max_kpt_num=100)\n\n# model settings\nmodel = dict(\n    type='PoseAnythingModel',\n    pretrained='pretrained/swinv2_small_1k_500k.pth',\n    encoder_config=dict(\n        type='SwinTransformerV2',\n        embed_dim=96,\n        depths=[2, 2, 18, 2],\n        num_heads=[3, 6, 12, 24],\n        window_size=16,\n        drop_path_rate=0.3,\n        img_size=256,\n        upsample=\"bilinear\"\n    ),\n    keypoint_head=dict(\n        type='PoseHead',\n        in_channels=768,\n        transformer=dict(\n            type='EncoderDecoder',\n            d_model=256,\n            nhead=8,\n            num_encoder_layers=3,\n            num_decoder_layers=3,\n            dim_feedforward=768,\n            dropout=0.1,\n            similarity_proj_dim=256,\n            dynamic_proj_dim=128,\n            activation=\"relu\",\n            normalize_before=False,\n            return_intermediate_dec=True),\n        share_kpt_branch=False,\n        num_decoder_layer=3,\n        with_heatmap_loss=True,\n        \n        heatmap_loss_weight=2.0,\n        support_order_dropout=-1,\n        positional_encoding=dict(\n            type='SinePositionalEncoding', num_feats=128, normalize=True)),\n    # training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(\n        flip_test=False,\n        post_process='default',\n        shift_heatmap=True,\n        modulate_kernel=11))\n\ndata_cfg = dict(\n    image_size=[256, 256],\n    heatmap_size=[64, 64],\n    num_output_channels=channel_cfg['num_output_channels'],\n    num_joints=channel_cfg['dataset_joints'],\n    dataset_channel=channel_cfg['dataset_channel'],\n    inference_channel=channel_cfg['inference_channel'])\n\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='TopDownGetRandomScaleRotation', rot_factor=15,\n        scale_factor=0.15),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',\n            'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',\n        ]),\n]\n\nvalid_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    
dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',\n            'flip_pairs', 'category_id',\n            'skeleton',\n        ]),\n]\n\ntest_pipeline = valid_pipeline\n\ndata_root = 'data/mp100'\ndata = dict(\n    samples_per_gpu=8,\n    workers_per_gpu=8,\n    train=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split1_train.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        pipeline=train_pipeline),\n    val=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split1_val.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        num_queries=15,\n        num_episodes=100,\n        pipeline=valid_pipeline),\n    test=dict(\n        type='TestPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split1_test.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        num_queries=15,\n        num_episodes=200,\n        pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],\n        pipeline=test_pipeline),\n)\nvis_backends = [\n    dict(type='LocalVisBackend'),\n    dict(type='TensorboardVisBackend'),\n]\nvisualizer = dict(\n    type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')\n\nshuffle_cfg = dict(interval=1)\n"
  },
  {
    "path": "configs/5shot-swin/base_split2_config.py",
    "content": "log_level = 'INFO'\nload_from = None\nresume_from = None\ndist_params = dict(backend='nccl')\nworkflow = [('train', 1)]\ncheckpoint_config = dict(interval=20)\nevaluation = dict(\n    interval=25,\n    metric=['PCK', 'NME', 'AUC', 'EPE'],\n    key_indicator='PCK',\n    gpu_collect=True,\n    res_folder='')\noptimizer = dict(\n    type='Adam',\n    lr=1e-5,\n)\n\noptimizer_config = dict(grad_clip=None)\n# learning policy\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[160, 180])\ntotal_epochs = 200\nlog_config = dict(\n    interval=50,\n    hooks=[\n        dict(type='TextLoggerHook'),\n        dict(type='TensorboardLoggerHook')\n    ])\n\nchannel_cfg = dict(\n    num_output_channels=1,\n    dataset_joints=1,\n    dataset_channel=[\n        [\n            0,\n        ],\n    ],\n    inference_channel=[\n        0,\n    ],\n    max_kpt_num=100)\n\n# model settings\nmodel = dict(\n    type='PoseAnythingModel',\n    pretrained='pretrained/swinv2_small_1k_500k.pth',\n    encoder_config=dict(\n        type='SwinTransformerV2',\n        embed_dim=96,\n        depths=[2, 2, 18, 2],\n        num_heads=[3, 6, 12, 24],\n        window_size=16,\n        drop_path_rate=0.3,\n        img_size=256,\n        upsample=\"bilinear\"\n    ),\n    keypoint_head=dict(\n        type='PoseHead',\n        in_channels=768,\n        transformer=dict(\n            type='EncoderDecoder',\n            d_model=256,\n            nhead=8,\n            num_encoder_layers=3,\n            num_decoder_layers=3,\n            dim_feedforward=768,\n            dropout=0.1,\n            similarity_proj_dim=256,\n            dynamic_proj_dim=128,\n            activation=\"relu\",\n            normalize_before=False,\n            return_intermediate_dec=True),\n        share_kpt_branch=False,\n        num_decoder_layer=3,\n        with_heatmap_loss=True,\n        \n        heatmap_loss_weight=2.0,\n        support_order_dropout=-1,\n        positional_encoding=dict(\n            type='SinePositionalEncoding', num_feats=128, normalize=True)),\n    # training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(\n        flip_test=False,\n        post_process='default',\n        shift_heatmap=True,\n        modulate_kernel=11))\n\ndata_cfg = dict(\n    image_size=[256, 256],\n    heatmap_size=[64, 64],\n    num_output_channels=channel_cfg['num_output_channels'],\n    num_joints=channel_cfg['dataset_joints'],\n    dataset_channel=channel_cfg['dataset_channel'],\n    inference_channel=channel_cfg['inference_channel'])\n\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='TopDownGetRandomScaleRotation', rot_factor=15,\n        scale_factor=0.15),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',\n            'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',\n        ]),\n]\n\nvalid_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    
dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',\n            'flip_pairs', 'category_id',\n            'skeleton',\n        ]),\n]\n\ntest_pipeline = valid_pipeline\n\ndata_root = 'data/mp100'\ndata = dict(\n    samples_per_gpu=8,\n    workers_per_gpu=8,\n    train=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split2_train.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        pipeline=train_pipeline),\n    val=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split2_val.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        num_queries=15,\n        num_episodes=100,\n        pipeline=valid_pipeline),\n    test=dict(\n        type='TestPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split2_test.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        num_queries=15,\n        num_episodes=200,\n        pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],\n        pipeline=test_pipeline),\n)\nvis_backends = [\n    dict(type='LocalVisBackend'),\n    dict(type='TensorboardVisBackend'),\n]\nvisualizer = dict(\n    type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')\n\nshuffle_cfg = dict(interval=1)\n"
  },
  {
    "path": "configs/5shot-swin/base_split3_config.py",
    "content": "log_level = 'INFO'\nload_from = None\nresume_from = None\ndist_params = dict(backend='nccl')\nworkflow = [('train', 1)]\ncheckpoint_config = dict(interval=20)\nevaluation = dict(\n    interval=25,\n    metric=['PCK', 'NME', 'AUC', 'EPE'],\n    key_indicator='PCK',\n    gpu_collect=True,\n    res_folder='')\noptimizer = dict(\n    type='Adam',\n    lr=1e-5,\n)\n\noptimizer_config = dict(grad_clip=None)\n# learning policy\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[160, 180])\ntotal_epochs = 200\nlog_config = dict(\n    interval=50,\n    hooks=[\n        dict(type='TextLoggerHook'),\n        dict(type='TensorboardLoggerHook')\n    ])\n\nchannel_cfg = dict(\n    num_output_channels=1,\n    dataset_joints=1,\n    dataset_channel=[\n        [\n            0,\n        ],\n    ],\n    inference_channel=[\n        0,\n    ],\n    max_kpt_num=100)\n\n# model settings\nmodel = dict(\n    type='PoseAnythingModel',\n    pretrained='pretrained/swinv2_small_1k_500k.pth',\n    encoder_config=dict(\n        type='SwinTransformerV2',\n        embed_dim=96,\n        depths=[2, 2, 18, 2],\n        num_heads=[3, 6, 12, 24],\n        window_size=16,\n        drop_path_rate=0.3,\n        img_size=256,\n        upsample=\"bilinear\"\n    ),\n    keypoint_head=dict(\n        type='PoseHead',\n        in_channels=768,\n        transformer=dict(\n            type='EncoderDecoder',\n            d_model=256,\n            nhead=8,\n            num_encoder_layers=3,\n            num_decoder_layers=3,\n            dim_feedforward=768,\n            dropout=0.1,\n            similarity_proj_dim=256,\n            dynamic_proj_dim=128,\n            activation=\"relu\",\n            normalize_before=False,\n            return_intermediate_dec=True),\n        share_kpt_branch=False,\n        num_decoder_layer=3,\n        with_heatmap_loss=True,\n        \n        heatmap_loss_weight=2.0,\n        support_order_dropout=-1,\n        positional_encoding=dict(\n            type='SinePositionalEncoding', num_feats=128, normalize=True)),\n    # training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(\n        flip_test=False,\n        post_process='default',\n        shift_heatmap=True,\n        modulate_kernel=11))\n\ndata_cfg = dict(\n    image_size=[256, 256],\n    heatmap_size=[64, 64],\n    num_output_channels=channel_cfg['num_output_channels'],\n    num_joints=channel_cfg['dataset_joints'],\n    dataset_channel=channel_cfg['dataset_channel'],\n    inference_channel=channel_cfg['inference_channel'])\n\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='TopDownGetRandomScaleRotation', rot_factor=15,\n        scale_factor=0.15),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',\n            'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',\n        ]),\n]\n\nvalid_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    
dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',\n            'flip_pairs', 'category_id',\n            'skeleton',\n        ]),\n]\n\ntest_pipeline = valid_pipeline\n\ndata_root = 'data/mp100'\ndata = dict(\n    samples_per_gpu=8,\n    workers_per_gpu=8,\n    train=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split3_train.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        pipeline=train_pipeline),\n    val=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split3_val.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        num_queries=15,\n        num_episodes=100,\n        pipeline=valid_pipeline),\n    test=dict(\n        type='TestPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split3_test.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        num_queries=15,\n        num_episodes=200,\n        pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],\n        pipeline=test_pipeline),\n)\nvis_backends = [\n    dict(type='LocalVisBackend'),\n    dict(type='TensorboardVisBackend'),\n]\nvisualizer = dict(\n    type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')\n\nshuffle_cfg = dict(interval=1)\n"
  },
  {
    "path": "configs/5shot-swin/base_split4_config.py",
    "content": "log_level = 'INFO'\nload_from = None\nresume_from = None\ndist_params = dict(backend='nccl')\nworkflow = [('train', 1)]\ncheckpoint_config = dict(interval=20)\nevaluation = dict(\n    interval=25,\n    metric=['PCK', 'NME', 'AUC', 'EPE'],\n    key_indicator='PCK',\n    gpu_collect=True,\n    res_folder='')\noptimizer = dict(\n    type='Adam',\n    lr=1e-5,\n)\n\noptimizer_config = dict(grad_clip=None)\n# learning policy\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[160, 180])\ntotal_epochs = 200\nlog_config = dict(\n    interval=50,\n    hooks=[\n        dict(type='TextLoggerHook'),\n        dict(type='TensorboardLoggerHook')\n    ])\n\nchannel_cfg = dict(\n    num_output_channels=1,\n    dataset_joints=1,\n    dataset_channel=[\n        [\n            0,\n        ],\n    ],\n    inference_channel=[\n        0,\n    ],\n    max_kpt_num=100)\n\n# model settings\nmodel = dict(\n    type='PoseAnythingModel',\n    pretrained='pretrained/swinv2_small_1k_500k.pth',\n    encoder_config=dict(\n        type='SwinTransformerV2',\n        embed_dim=96,\n        depths=[2, 2, 18, 2],\n        num_heads=[3, 6, 12, 24],\n        window_size=16,\n        drop_path_rate=0.3,\n        img_size=256,\n        upsample=\"bilinear\"\n    ),\n    keypoint_head=dict(\n        type='PoseHead',\n        in_channels=768,\n        transformer=dict(\n            type='EncoderDecoder',\n            d_model=256,\n            nhead=8,\n            num_encoder_layers=3,\n            num_decoder_layers=3,\n            dim_feedforward=768,\n            dropout=0.1,\n            similarity_proj_dim=256,\n            dynamic_proj_dim=128,\n            activation=\"relu\",\n            normalize_before=False,\n            return_intermediate_dec=True),\n        share_kpt_branch=False,\n        num_decoder_layer=3,\n        with_heatmap_loss=True,\n        \n        heatmap_loss_weight=2.0,\n        support_order_dropout=-1,\n        positional_encoding=dict(\n            type='SinePositionalEncoding', num_feats=128, normalize=True)),\n    # training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(\n        flip_test=False,\n        post_process='default',\n        shift_heatmap=True,\n        modulate_kernel=11))\n\ndata_cfg = dict(\n    image_size=[256, 256],\n    heatmap_size=[64, 64],\n    num_output_channels=channel_cfg['num_output_channels'],\n    num_joints=channel_cfg['dataset_joints'],\n    dataset_channel=channel_cfg['dataset_channel'],\n    inference_channel=channel_cfg['inference_channel'])\n\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='TopDownGetRandomScaleRotation', rot_factor=15,\n        scale_factor=0.15),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',\n            'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',\n        ]),\n]\n\nvalid_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    
dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',\n            'flip_pairs', 'category_id',\n            'skeleton',\n        ]),\n]\n\ntest_pipeline = valid_pipeline\n\ndata_root = 'data/mp100'\ndata = dict(\n    samples_per_gpu=8,\n    workers_per_gpu=8,\n    train=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split4_train.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        pipeline=train_pipeline),\n    val=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split4_val.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        num_queries=15,\n        num_episodes=100,\n        pipeline=valid_pipeline),\n    test=dict(\n        type='TestPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split4_test.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        num_queries=15,\n        num_episodes=200,\n        pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],\n        pipeline=test_pipeline),\n)\nvis_backends = [\n    dict(type='LocalVisBackend'),\n    dict(type='TensorboardVisBackend'),\n]\nvisualizer = dict(\n    type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')\n\nshuffle_cfg = dict(interval=1)\n"
  },
  {
    "path": "configs/5shot-swin/base_split5_config.py",
    "content": "log_level = 'INFO'\nload_from = None\nresume_from = None\ndist_params = dict(backend='nccl')\nworkflow = [('train', 1)]\ncheckpoint_config = dict(interval=20)\nevaluation = dict(\n    interval=25,\n    metric=['PCK', 'NME', 'AUC', 'EPE'],\n    key_indicator='PCK',\n    gpu_collect=True,\n    res_folder='')\noptimizer = dict(\n    type='Adam',\n    lr=1e-5,\n)\n\noptimizer_config = dict(grad_clip=None)\n# learning policy\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[160, 180])\ntotal_epochs = 200\nlog_config = dict(\n    interval=50,\n    hooks=[\n        dict(type='TextLoggerHook'),\n        dict(type='TensorboardLoggerHook')\n    ])\n\nchannel_cfg = dict(\n    num_output_channels=1,\n    dataset_joints=1,\n    dataset_channel=[\n        [\n            0,\n        ],\n    ],\n    inference_channel=[\n        0,\n    ],\n    max_kpt_num=100)\n\n# model settings\nmodel = dict(\n    type='PoseAnythingModel',\n    pretrained='pretrained/swinv2_small_1k_500k.pth',\n    encoder_config=dict(\n        type='SwinTransformerV2',\n        embed_dim=96,\n        depths=[2, 2, 18, 2],\n        num_heads=[3, 6, 12, 24],\n        window_size=16,\n        drop_path_rate=0.3,\n        img_size=256,\n        upsample=\"bilinear\"\n    ),\n    keypoint_head=dict(\n        type='PoseHead',\n        in_channels=768,\n        transformer=dict(\n            type='EncoderDecoder',\n            d_model=256,\n            nhead=8,\n            num_encoder_layers=3,\n            num_decoder_layers=3,\n            dim_feedforward=768,\n            dropout=0.1,\n            similarity_proj_dim=256,\n            dynamic_proj_dim=128,\n            activation=\"relu\",\n            normalize_before=False,\n            return_intermediate_dec=True),\n        share_kpt_branch=False,\n        num_decoder_layer=3,\n        with_heatmap_loss=True,\n        \n        heatmap_loss_weight=2.0,\n        support_order_dropout=-1,\n        positional_encoding=dict(\n            type='SinePositionalEncoding', num_feats=128, normalize=True)),\n    # training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(\n        flip_test=False,\n        post_process='default',\n        shift_heatmap=True,\n        modulate_kernel=11))\n\ndata_cfg = dict(\n    image_size=[256, 256],\n    heatmap_size=[64, 64],\n    num_output_channels=channel_cfg['num_output_channels'],\n    num_joints=channel_cfg['dataset_joints'],\n    dataset_channel=channel_cfg['dataset_channel'],\n    inference_channel=channel_cfg['inference_channel'])\n\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='TopDownGetRandomScaleRotation', rot_factor=15,\n        scale_factor=0.15),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',\n            'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',\n        ]),\n]\n\nvalid_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    
dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',\n            'flip_pairs', 'category_id',\n            'skeleton',\n        ]),\n]\n\ntest_pipeline = valid_pipeline\n\ndata_root = 'data/mp100'\ndata = dict(\n    samples_per_gpu=8,\n    workers_per_gpu=8,\n    train=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split5_train.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        pipeline=train_pipeline),\n    val=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split5_val.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        num_queries=15,\n        num_episodes=100,\n        pipeline=valid_pipeline),\n    test=dict(\n        type='TestPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split5_test.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        num_queries=15,\n        num_episodes=200,\n        pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],\n        pipeline=test_pipeline),\n)\nvis_backends = [\n    dict(type='LocalVisBackend'),\n    dict(type='TensorboardVisBackend'),\n]\nvisualizer = dict(\n    type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')\n\nshuffle_cfg = dict(interval=1)\n"
  },
  {
    "path": "configs/5shot-swin/graph_split1_config.py",
    "content": "log_level = 'INFO'\nload_from = None\nresume_from = None\ndist_params = dict(backend='nccl')\nworkflow = [('train', 1)]\ncheckpoint_config = dict(interval=20)\nevaluation = dict(\n    interval=25,\n    metric=['PCK', 'NME', 'AUC', 'EPE'],\n    key_indicator='PCK',\n    gpu_collect=True,\n    res_folder='')\noptimizer = dict(\n    type='Adam',\n    lr=1e-5,\n)\n\noptimizer_config = dict(grad_clip=None)\n# learning policy\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[160, 180])\ntotal_epochs = 200\nlog_config = dict(\n    interval=50,\n    hooks=[\n        dict(type='TextLoggerHook'),\n        dict(type='TensorboardLoggerHook')\n    ])\n\nchannel_cfg = dict(\n    num_output_channels=1,\n    dataset_joints=1,\n    dataset_channel=[\n        [\n            0,\n        ],\n    ],\n    inference_channel=[\n        0,\n    ],\n    max_kpt_num=100)\n\n# model settings\nmodel = dict(\n    type='PoseAnythingModel',\n    pretrained='pretrained/swinv2_small_1k_500k.pth',\n    encoder_config=dict(\n        type='SwinTransformerV2',\n        embed_dim=96,\n        depths=[2, 2, 18, 2],\n        num_heads=[3, 6, 12, 24],\n        window_size=16,\n        drop_path_rate=0.3,\n        img_size=256,\n        upsample=\"bilinear\"\n    ),\n    keypoint_head=dict(\n        type='PoseHead',\n        in_channels=768,\n        transformer=dict(\n            type='EncoderDecoder',\n            d_model=256,\n            nhead=8,\n            num_encoder_layers=3,\n            num_decoder_layers=3,\n            graph_decoder='pre',\n            dim_feedforward=768,\n            dropout=0.1,\n            similarity_proj_dim=256,\n            dynamic_proj_dim=128,\n            activation=\"relu\",\n            normalize_before=False,\n            return_intermediate_dec=True),\n        share_kpt_branch=False,\n        num_decoder_layer=3,\n        with_heatmap_loss=True,\n        \n        heatmap_loss_weight=2.0,\n        support_order_dropout=-1,\n        positional_encoding=dict(\n            type='SinePositionalEncoding', num_feats=128, normalize=True)),\n    # training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(\n        flip_test=False,\n        post_process='default',\n        shift_heatmap=True,\n        modulate_kernel=11))\n\ndata_cfg = dict(\n    image_size=[256, 256],\n    heatmap_size=[64, 64],\n    num_output_channels=channel_cfg['num_output_channels'],\n    num_joints=channel_cfg['dataset_joints'],\n    dataset_channel=channel_cfg['dataset_channel'],\n    inference_channel=channel_cfg['inference_channel'])\n\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='TopDownGetRandomScaleRotation', rot_factor=15,\n        scale_factor=0.15),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',\n            'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',\n        ]),\n]\n\nvalid_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        
std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',\n            'flip_pairs', 'category_id',\n            'skeleton',\n        ]),\n]\n\ntest_pipeline = valid_pipeline\n\ndata_root = 'data/mp100'\ndata = dict(\n    samples_per_gpu=8,\n    workers_per_gpu=8,\n    train=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split1_train.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        pipeline=train_pipeline),\n    val=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split1_val.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        num_queries=15,\n        num_episodes=100,\n        pipeline=valid_pipeline),\n    test=dict(\n        type='TestPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split1_test.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        num_queries=15,\n        num_episodes=200,\n        pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],\n        pipeline=test_pipeline),\n)\nvis_backends = [\n    dict(type='LocalVisBackend'),\n    dict(type='TensorboardVisBackend'),\n]\nvisualizer = dict(\n    type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')\n\nshuffle_cfg = dict(interval=1)\n"
  },
  {
    "path": "configs/5shot-swin/graph_split2_config.py",
    "content": "log_level = 'INFO'\nload_from = None\nresume_from = None\ndist_params = dict(backend='nccl')\nworkflow = [('train', 1)]\ncheckpoint_config = dict(interval=20)\nevaluation = dict(\n    interval=25,\n    metric=['PCK', 'NME', 'AUC', 'EPE'],\n    key_indicator='PCK',\n    gpu_collect=True,\n    res_folder='')\noptimizer = dict(\n    type='Adam',\n    lr=1e-5,\n)\n\noptimizer_config = dict(grad_clip=None)\n# learning policy\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[160, 180])\ntotal_epochs = 200\nlog_config = dict(\n    interval=50,\n    hooks=[\n        dict(type='TextLoggerHook'),\n        dict(type='TensorboardLoggerHook')\n    ])\n\nchannel_cfg = dict(\n    num_output_channels=1,\n    dataset_joints=1,\n    dataset_channel=[\n        [\n            0,\n        ],\n    ],\n    inference_channel=[\n        0,\n    ],\n    max_kpt_num=100)\n\n# model settings\nmodel = dict(\n    type='PoseAnythingModel',\n    pretrained='pretrained/swinv2_small_1k_500k.pth',\n    encoder_config=dict(\n        type='SwinTransformerV2',\n        embed_dim=96,\n        depths=[2, 2, 18, 2],\n        num_heads=[3, 6, 12, 24],\n        window_size=16,\n        drop_path_rate=0.3,\n        img_size=256,\n        upsample=\"bilinear\"\n    ),\n    keypoint_head=dict(\n        type='PoseHead',\n        in_channels=768,\n        transformer=dict(\n            type='EncoderDecoder',\n            d_model=256,\n            nhead=8,\n            num_encoder_layers=3,\n            num_decoder_layers=3,\n            graph_decoder='pre',\n            dim_feedforward=768,\n            dropout=0.1,\n            similarity_proj_dim=256,\n            dynamic_proj_dim=128,\n            activation=\"relu\",\n            normalize_before=False,\n            return_intermediate_dec=True),\n        share_kpt_branch=False,\n        num_decoder_layer=3,\n        with_heatmap_loss=True,\n        \n        heatmap_loss_weight=2.0,\n        support_order_dropout=-1,\n        positional_encoding=dict(\n            type='SinePositionalEncoding', num_feats=128, normalize=True)),\n    # training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(\n        flip_test=False,\n        post_process='default',\n        shift_heatmap=True,\n        modulate_kernel=11))\n\ndata_cfg = dict(\n    image_size=[256, 256],\n    heatmap_size=[64, 64],\n    num_output_channels=channel_cfg['num_output_channels'],\n    num_joints=channel_cfg['dataset_joints'],\n    dataset_channel=channel_cfg['dataset_channel'],\n    inference_channel=channel_cfg['inference_channel'])\n\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='TopDownGetRandomScaleRotation', rot_factor=15,\n        scale_factor=0.15),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',\n            'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',\n        ]),\n]\n\nvalid_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        
std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',\n            'flip_pairs', 'category_id',\n            'skeleton',\n        ]),\n]\n\ntest_pipeline = valid_pipeline\n\ndata_root = 'data/mp100'\ndata = dict(\n    samples_per_gpu=8,\n    workers_per_gpu=8,\n    train=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split2_train.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        pipeline=train_pipeline),\n    val=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split2_val.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        num_queries=15,\n        num_episodes=100,\n        pipeline=valid_pipeline),\n    test=dict(\n        type='TestPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split2_test.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        num_queries=15,\n        num_episodes=200,\n        pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],\n        pipeline=test_pipeline),\n)\nvis_backends = [\n    dict(type='LocalVisBackend'),\n    dict(type='TensorboardVisBackend'),\n]\nvisualizer = dict(\n    type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')\n\nshuffle_cfg = dict(interval=1)\n"
  },
  {
    "path": "configs/5shot-swin/graph_split3_config.py",
    "content": "log_level = 'INFO'\nload_from = None\nresume_from = None\ndist_params = dict(backend='nccl')\nworkflow = [('train', 1)]\ncheckpoint_config = dict(interval=20)\nevaluation = dict(\n    interval=25,\n    metric=['PCK', 'NME', 'AUC', 'EPE'],\n    key_indicator='PCK',\n    gpu_collect=True,\n    res_folder='')\noptimizer = dict(\n    type='Adam',\n    lr=1e-5,\n)\n\noptimizer_config = dict(grad_clip=None)\n# learning policy\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[160, 180])\ntotal_epochs = 200\nlog_config = dict(\n    interval=50,\n    hooks=[\n        dict(type='TextLoggerHook'),\n        dict(type='TensorboardLoggerHook')\n    ])\n\nchannel_cfg = dict(\n    num_output_channels=1,\n    dataset_joints=1,\n    dataset_channel=[\n        [\n            0,\n        ],\n    ],\n    inference_channel=[\n        0,\n    ],\n    max_kpt_num=100)\n\n# model settings\nmodel = dict(\n    type='PoseAnythingModel',\n    pretrained='pretrained/swinv2_small_1k_500k.pth',\n    encoder_config=dict(\n        type='SwinTransformerV2',\n        embed_dim=96,\n        depths=[2, 2, 18, 2],\n        num_heads=[3, 6, 12, 24],\n        window_size=16,\n        drop_path_rate=0.3,\n        img_size=256,\n        upsample=\"bilinear\"\n    ),\n    keypoint_head=dict(\n        type='PoseHead',\n        in_channels=768,\n        transformer=dict(\n            type='EncoderDecoder',\n            d_model=256,\n            nhead=8,\n            num_encoder_layers=3,\n            num_decoder_layers=3,\n            graph_decoder='pre',\n            dim_feedforward=768,\n            dropout=0.1,\n            similarity_proj_dim=256,\n            dynamic_proj_dim=128,\n            activation=\"relu\",\n            normalize_before=False,\n            return_intermediate_dec=True),\n        share_kpt_branch=False,\n        num_decoder_layer=3,\n        with_heatmap_loss=True,\n        \n        heatmap_loss_weight=2.0,\n        support_order_dropout=-1,\n        positional_encoding=dict(\n            type='SinePositionalEncoding', num_feats=128, normalize=True)),\n    # training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(\n        flip_test=False,\n        post_process='default',\n        shift_heatmap=True,\n        modulate_kernel=11))\n\ndata_cfg = dict(\n    image_size=[256, 256],\n    heatmap_size=[64, 64],\n    num_output_channels=channel_cfg['num_output_channels'],\n    num_joints=channel_cfg['dataset_joints'],\n    dataset_channel=channel_cfg['dataset_channel'],\n    inference_channel=channel_cfg['inference_channel'])\n\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='TopDownGetRandomScaleRotation', rot_factor=15,\n        scale_factor=0.15),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',\n            'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',\n        ]),\n]\n\nvalid_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        
std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',\n            'flip_pairs', 'category_id',\n            'skeleton',\n        ]),\n]\n\ntest_pipeline = valid_pipeline\n\ndata_root = 'data/mp100'\ndata = dict(\n    samples_per_gpu=8,\n    workers_per_gpu=8,\n    train=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split3_train.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        pipeline=train_pipeline),\n    val=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split3_val.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        num_queries=15,\n        num_episodes=100,\n        pipeline=valid_pipeline),\n    test=dict(\n        type='TestPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split3_test.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        num_queries=15,\n        num_episodes=200,\n        pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],\n        pipeline=test_pipeline),\n)\nvis_backends = [\n    dict(type='LocalVisBackend'),\n    dict(type='TensorboardVisBackend'),\n]\nvisualizer = dict(\n    type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')\n\nshuffle_cfg = dict(interval=1)\n"
  },
  {
    "path": "configs/5shot-swin/graph_split4_config.py",
    "content": "log_level = 'INFO'\nload_from = None\nresume_from = None\ndist_params = dict(backend='nccl')\nworkflow = [('train', 1)]\ncheckpoint_config = dict(interval=20)\nevaluation = dict(\n    interval=25,\n    metric=['PCK', 'NME', 'AUC', 'EPE'],\n    key_indicator='PCK',\n    gpu_collect=True,\n    res_folder='')\noptimizer = dict(\n    type='Adam',\n    lr=1e-5,\n)\n\noptimizer_config = dict(grad_clip=None)\n# learning policy\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[160, 180])\ntotal_epochs = 200\nlog_config = dict(\n    interval=50,\n    hooks=[\n        dict(type='TextLoggerHook'),\n        dict(type='TensorboardLoggerHook')\n    ])\n\nchannel_cfg = dict(\n    num_output_channels=1,\n    dataset_joints=1,\n    dataset_channel=[\n        [\n            0,\n        ],\n    ],\n    inference_channel=[\n        0,\n    ],\n    max_kpt_num=100)\n\n# model settings\nmodel = dict(\n    type='PoseAnythingModel',\n    pretrained='pretrained/swinv2_small_1k_500k.pth',\n    encoder_config=dict(\n        type='SwinTransformerV2',\n        embed_dim=96,\n        depths=[2, 2, 18, 2],\n        num_heads=[3, 6, 12, 24],\n        window_size=16,\n        drop_path_rate=0.3,\n        img_size=256,\n        upsample=\"bilinear\"\n    ),\n    keypoint_head=dict(\n        type='PoseHead',\n        in_channels=768,\n        transformer=dict(\n            type='EncoderDecoder',\n            d_model=256,\n            nhead=8,\n            num_encoder_layers=3,\n            num_decoder_layers=3,\n            graph_decoder='pre',\n            dim_feedforward=768,\n            dropout=0.1,\n            similarity_proj_dim=256,\n            dynamic_proj_dim=128,\n            activation=\"relu\",\n            normalize_before=False,\n            return_intermediate_dec=True),\n        share_kpt_branch=False,\n        num_decoder_layer=3,\n        with_heatmap_loss=True,\n        \n        heatmap_loss_weight=2.0,\n        support_order_dropout=-1,\n        positional_encoding=dict(\n            type='SinePositionalEncoding', num_feats=128, normalize=True)),\n    # training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(\n        flip_test=False,\n        post_process='default',\n        shift_heatmap=True,\n        modulate_kernel=11))\n\ndata_cfg = dict(\n    image_size=[256, 256],\n    heatmap_size=[64, 64],\n    num_output_channels=channel_cfg['num_output_channels'],\n    num_joints=channel_cfg['dataset_joints'],\n    dataset_channel=channel_cfg['dataset_channel'],\n    inference_channel=channel_cfg['inference_channel'])\n\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='TopDownGetRandomScaleRotation', rot_factor=15,\n        scale_factor=0.15),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',\n            'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',\n        ]),\n]\n\nvalid_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        
std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',\n            'flip_pairs', 'category_id',\n            'skeleton',\n        ]),\n]\n\ntest_pipeline = valid_pipeline\n\ndata_root = 'data/mp100'\ndata = dict(\n    samples_per_gpu=8,\n    workers_per_gpu=8,\n    train=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split4_train.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        pipeline=train_pipeline),\n    val=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split4_val.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        num_queries=15,\n        num_episodes=100,\n        pipeline=valid_pipeline),\n    test=dict(\n        type='TestPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split4_test.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        num_queries=15,\n        num_episodes=200,\n        pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],\n        pipeline=test_pipeline),\n)\nvis_backends = [\n    dict(type='LocalVisBackend'),\n    dict(type='TensorboardVisBackend'),\n]\nvisualizer = dict(\n    type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')\n\nshuffle_cfg = dict(interval=1)\n"
  },
  {
    "path": "configs/5shot-swin/graph_split5_config.py",
    "content": "log_level = 'INFO'\nload_from = None\nresume_from = None\ndist_params = dict(backend='nccl')\nworkflow = [('train', 1)]\ncheckpoint_config = dict(interval=20)\nevaluation = dict(\n    interval=25,\n    metric=['PCK', 'NME', 'AUC', 'EPE'],\n    key_indicator='PCK',\n    gpu_collect=True,\n    res_folder='')\noptimizer = dict(\n    type='Adam',\n    lr=1e-5,\n)\n\noptimizer_config = dict(grad_clip=None)\n# learning policy\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[160, 180])\ntotal_epochs = 200\nlog_config = dict(\n    interval=50,\n    hooks=[\n        dict(type='TextLoggerHook'),\n        dict(type='TensorboardLoggerHook')\n    ])\n\nchannel_cfg = dict(\n    num_output_channels=1,\n    dataset_joints=1,\n    dataset_channel=[\n        [\n            0,\n        ],\n    ],\n    inference_channel=[\n        0,\n    ],\n    max_kpt_num=100)\n\n# model settings\nmodel = dict(\n    type='PoseAnythingModel',\n    pretrained='pretrained/swinv2_small_1k_500k.pth',\n    encoder_config=dict(\n        type='SwinTransformerV2',\n        embed_dim=96,\n        depths=[2, 2, 18, 2],\n        num_heads=[3, 6, 12, 24],\n        window_size=16,\n        drop_path_rate=0.3,\n        img_size=256,\n        upsample=\"bilinear\"\n    ),\n    keypoint_head=dict(\n        type='PoseHead',\n        in_channels=768,\n        transformer=dict(\n            type='EncoderDecoder',\n            d_model=256,\n            nhead=8,\n            num_encoder_layers=3,\n            num_decoder_layers=3,\n            graph_decoder='pre',\n            dim_feedforward=768,\n            dropout=0.1,\n            similarity_proj_dim=256,\n            dynamic_proj_dim=128,\n            activation=\"relu\",\n            normalize_before=False,\n            return_intermediate_dec=True),\n        share_kpt_branch=False,\n        num_decoder_layer=3,\n        with_heatmap_loss=True,\n        \n        heatmap_loss_weight=2.0,\n        support_order_dropout=-1,\n        positional_encoding=dict(\n            type='SinePositionalEncoding', num_feats=128, normalize=True)),\n    # training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(\n        flip_test=False,\n        post_process='default',\n        shift_heatmap=True,\n        modulate_kernel=11))\n\ndata_cfg = dict(\n    image_size=[256, 256],\n    heatmap_size=[64, 64],\n    num_output_channels=channel_cfg['num_output_channels'],\n    num_joints=channel_cfg['dataset_joints'],\n    dataset_channel=channel_cfg['dataset_channel'],\n    inference_channel=channel_cfg['inference_channel'])\n\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='TopDownGetRandomScaleRotation', rot_factor=15,\n        scale_factor=0.15),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',\n            'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',\n        ]),\n]\n\nvalid_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        
std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',\n            'flip_pairs', 'category_id',\n            'skeleton',\n        ]),\n]\n\ntest_pipeline = valid_pipeline\n\ndata_root = 'data/mp100'\ndata = dict(\n    samples_per_gpu=8,\n    workers_per_gpu=8,\n    train=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split5_train.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        pipeline=train_pipeline),\n    val=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split5_val.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        num_queries=15,\n        num_episodes=100,\n        pipeline=valid_pipeline),\n    test=dict(\n        type='TestPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split5_test.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        num_queries=15,\n        num_episodes=200,\n        pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],\n        pipeline=test_pipeline),\n)\nvis_backends = [\n    dict(type='LocalVisBackend'),\n    dict(type='TensorboardVisBackend'),\n]\nvisualizer = dict(\n    type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')\n\nshuffle_cfg = dict(interval=1)\n"
  },
  {
    "path": "configs/5shots/base_split1_config.py",
    "content": "log_level = 'INFO'\nload_from = None\nresume_from = None\ndist_params = dict(backend='nccl')\nworkflow = [('train', 1)]\ncheckpoint_config = dict(interval=20)\nevaluation = dict(\n    interval=25,\n    metric=['PCK', 'NME', 'AUC', 'EPE'],\n    key_indicator='PCK',\n    gpu_collect=True,\n    res_folder='')\noptimizer = dict(\n    type='Adam',\n    lr=1e-5,\n)\n\noptimizer_config = dict(grad_clip=None)\n# learning policy\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[160, 180])\ntotal_epochs = 200\nlog_config = dict(\n    interval=50,\n    hooks=[\n        dict(type='TextLoggerHook'),\n        dict(type='TensorboardLoggerHook')\n    ])\n\nchannel_cfg = dict(\n    num_output_channels=1,\n    dataset_joints=1,\n    dataset_channel=[\n        [\n            0,\n        ],\n    ],\n    inference_channel=[\n        0,\n    ],\n    max_kpt_num=100)\n\n# model settings\nmodel = dict(\n    type='PoseAnythingModel',\n    pretrained='pretrained/swinv2_tiny_patch4_window16_256.pth',\n    encoder_config=dict(\n        type='SwinTransformerV2',\n        embed_dim=96,\n        depths=[2, 2, 6, 2],\n        num_heads=[3, 6, 12, 24],\n        window_size=16,\n        drop_path_rate=0.2,\n        img_size=256,\n        upsample=\"bilinear\"\n    ),\n    keypoint_head=dict(\n        type='PoseHead',\n        in_channels=768,\n        transformer=dict(\n            type='EncoderDecoder',\n            d_model=256,\n            nhead=8,\n            num_encoder_layers=3,\n            num_decoder_layers=3,\n            dim_feedforward=768,\n            dropout=0.1,\n            similarity_proj_dim=256,\n            dynamic_proj_dim=128,\n            activation=\"relu\",\n            normalize_before=False,\n            return_intermediate_dec=True),\n        share_kpt_branch=False,\n        num_decoder_layer=3,\n        with_heatmap_loss=True,\n        \n        heatmap_loss_weight=2.0,\n        support_order_dropout=-1,\n        positional_encoding=dict(\n            type='SinePositionalEncoding', num_feats=128, normalize=True)),\n    # training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(\n        flip_test=False,\n        post_process='default',\n        shift_heatmap=True,\n        modulate_kernel=11))\n\ndata_cfg = dict(\n    image_size=[256, 256],\n    heatmap_size=[64, 64],\n    num_output_channels=channel_cfg['num_output_channels'],\n    num_joints=channel_cfg['dataset_joints'],\n    dataset_channel=channel_cfg['dataset_channel'],\n    inference_channel=channel_cfg['inference_channel'])\n\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='TopDownGetRandomScaleRotation', rot_factor=15,\n        scale_factor=0.15),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',\n            'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',\n        ]),\n]\n\nvalid_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 
0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',\n            'flip_pairs', 'category_id',\n            'skeleton',\n        ]),\n]\n\ntest_pipeline = valid_pipeline\n\ndata_root = 'data/mp100'\ndata = dict(\n    samples_per_gpu=8,\n    workers_per_gpu=8,\n    train=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split1_train.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        pipeline=train_pipeline),\n    val=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split1_val.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        num_queries=15,\n        num_episodes=100,\n        pipeline=valid_pipeline),\n    test=dict(\n        type='TestPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split1_test.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        num_queries=15,\n        num_episodes=200,\n        pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],\n        pipeline=test_pipeline),\n)\nvis_backends = [\n    dict(type='LocalVisBackend'),\n    dict(type='TensorboardVisBackend'),\n]\nvisualizer = dict(\n    type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')\n\nshuffle_cfg = dict(interval=1)\n"
  },
  {
    "path": "configs/5shots/base_split2_config.py",
    "content": "log_level = 'INFO'\nload_from = None\nresume_from = None\ndist_params = dict(backend='nccl')\nworkflow = [('train', 1)]\ncheckpoint_config = dict(interval=20)\nevaluation = dict(\n    interval=25,\n    metric=['PCK', 'NME', 'AUC', 'EPE'],\n    key_indicator='PCK',\n    gpu_collect=True,\n    res_folder='')\noptimizer = dict(\n    type='Adam',\n    lr=1e-5,\n)\n\noptimizer_config = dict(grad_clip=None)\n# learning policy\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[160, 180])\ntotal_epochs = 200\nlog_config = dict(\n    interval=50,\n    hooks=[\n        dict(type='TextLoggerHook'),\n        dict(type='TensorboardLoggerHook')\n    ])\n\nchannel_cfg = dict(\n    num_output_channels=1,\n    dataset_joints=1,\n    dataset_channel=[\n        [\n            0,\n        ],\n    ],\n    inference_channel=[\n        0,\n    ],\n    max_kpt_num=100)\n\n# model settings\nmodel = dict(\n    type='PoseAnythingModel',\n    pretrained='pretrained/swinv2_tiny_patch4_window16_256.pth',\n    encoder_config=dict(\n        type='SwinTransformerV2',\n        embed_dim=96,\n        depths=[2, 2, 6, 2],\n        num_heads=[3, 6, 12, 24],\n        window_size=16,\n        drop_path_rate=0.2,\n        img_size=256,\n        upsample=\"bilinear\"\n    ),\n    keypoint_head=dict(\n        type='PoseHead',\n        in_channels=768,\n        transformer=dict(\n            type='EncoderDecoder',\n            d_model=256,\n            nhead=8,\n            num_encoder_layers=3,\n            num_decoder_layers=3,\n            dim_feedforward=768,\n            dropout=0.1,\n            similarity_proj_dim=256,\n            dynamic_proj_dim=128,\n            activation=\"relu\",\n            normalize_before=False,\n            return_intermediate_dec=True),\n        share_kpt_branch=False,\n        num_decoder_layer=3,\n        with_heatmap_loss=True,\n        \n        heatmap_loss_weight=2.0,\n        support_order_dropout=-1,\n        positional_encoding=dict(\n            type='SinePositionalEncoding', num_feats=128, normalize=True)),\n    # training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(\n        flip_test=False,\n        post_process='default',\n        shift_heatmap=True,\n        modulate_kernel=11))\n\ndata_cfg = dict(\n    image_size=[256, 256],\n    heatmap_size=[64, 64],\n    num_output_channels=channel_cfg['num_output_channels'],\n    num_joints=channel_cfg['dataset_joints'],\n    dataset_channel=channel_cfg['dataset_channel'],\n    inference_channel=channel_cfg['inference_channel'])\n\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='TopDownGetRandomScaleRotation', rot_factor=15,\n        scale_factor=0.15),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',\n            'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',\n        ]),\n]\n\nvalid_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 
0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',\n            'flip_pairs', 'category_id',\n            'skeleton',\n        ]),\n]\n\ntest_pipeline = valid_pipeline\n\ndata_root = 'data/mp100'\ndata = dict(\n    samples_per_gpu=8,\n    workers_per_gpu=8,\n    train=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split2_train.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        pipeline=train_pipeline),\n    val=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split2_val.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        num_queries=15,\n        num_episodes=100,\n        pipeline=valid_pipeline),\n    test=dict(\n        type='TestPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split2_test.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        num_queries=15,\n        num_episodes=200,\n        pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],\n        pipeline=test_pipeline),\n)\nvis_backends = [\n    dict(type='LocalVisBackend'),\n    dict(type='TensorboardVisBackend'),\n]\nvisualizer = dict(\n    type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')\n\nshuffle_cfg = dict(interval=1)\n"
  },
  {
    "path": "configs/5shots/base_split3_config.py",
    "content": "log_level = 'INFO'\nload_from = None\nresume_from = None\ndist_params = dict(backend='nccl')\nworkflow = [('train', 1)]\ncheckpoint_config = dict(interval=20)\nevaluation = dict(\n    interval=25,\n    metric=['PCK', 'NME', 'AUC', 'EPE'],\n    key_indicator='PCK',\n    gpu_collect=True,\n    res_folder='')\noptimizer = dict(\n    type='Adam',\n    lr=1e-5,\n)\n\noptimizer_config = dict(grad_clip=None)\n# learning policy\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[160, 180])\ntotal_epochs = 200\nlog_config = dict(\n    interval=50,\n    hooks=[\n        dict(type='TextLoggerHook'),\n        dict(type='TensorboardLoggerHook')\n    ])\n\nchannel_cfg = dict(\n    num_output_channels=1,\n    dataset_joints=1,\n    dataset_channel=[\n        [\n            0,\n        ],\n    ],\n    inference_channel=[\n        0,\n    ],\n    max_kpt_num=100)\n\n# model settings\nmodel = dict(\n    type='PoseAnythingModel',\n    pretrained='pretrained/swinv2_tiny_patch4_window16_256.pth',\n    encoder_config=dict(\n        type='SwinTransformerV2',\n        embed_dim=96,\n        depths=[2, 2, 6, 2],\n        num_heads=[3, 6, 12, 24],\n        window_size=16,\n        drop_path_rate=0.2,\n        img_size=256,\n        upsample=\"bilinear\"\n    ),\n    keypoint_head=dict(\n        type='PoseHead',\n        in_channels=768,\n        transformer=dict(\n            type='EncoderDecoder',\n            d_model=256,\n            nhead=8,\n            num_encoder_layers=3,\n            num_decoder_layers=3,\n            dim_feedforward=768,\n            dropout=0.1,\n            similarity_proj_dim=256,\n            dynamic_proj_dim=128,\n            activation=\"relu\",\n            normalize_before=False,\n            return_intermediate_dec=True),\n        share_kpt_branch=False,\n        num_decoder_layer=3,\n        with_heatmap_loss=True,\n        \n        heatmap_loss_weight=2.0,\n        support_order_dropout=-1,\n        positional_encoding=dict(\n            type='SinePositionalEncoding', num_feats=128, normalize=True)),\n    # training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(\n        flip_test=False,\n        post_process='default',\n        shift_heatmap=True,\n        modulate_kernel=11))\n\ndata_cfg = dict(\n    image_size=[256, 256],\n    heatmap_size=[64, 64],\n    num_output_channels=channel_cfg['num_output_channels'],\n    num_joints=channel_cfg['dataset_joints'],\n    dataset_channel=channel_cfg['dataset_channel'],\n    inference_channel=channel_cfg['inference_channel'])\n\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='TopDownGetRandomScaleRotation', rot_factor=15,\n        scale_factor=0.15),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',\n            'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',\n        ]),\n]\n\nvalid_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 
0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',\n            'flip_pairs', 'category_id',\n            'skeleton',\n        ]),\n]\n\ntest_pipeline = valid_pipeline\n\ndata_root = 'data/mp100'\ndata = dict(\n    samples_per_gpu=8,\n    workers_per_gpu=8,\n    train=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split3_train.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        pipeline=train_pipeline),\n    val=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split3_val.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        num_queries=15,\n        num_episodes=100,\n        pipeline=valid_pipeline),\n    test=dict(\n        type='TestPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split3_test.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        num_queries=15,\n        num_episodes=200,\n        pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],\n        pipeline=test_pipeline),\n)\nvis_backends = [\n    dict(type='LocalVisBackend'),\n    dict(type='TensorboardVisBackend'),\n]\nvisualizer = dict(\n    type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')\n\nshuffle_cfg = dict(interval=1)\n"
  },
  {
    "path": "configs/5shots/base_split4_config.py",
    "content": "log_level = 'INFO'\nload_from = None\nresume_from = None\ndist_params = dict(backend='nccl')\nworkflow = [('train', 1)]\ncheckpoint_config = dict(interval=20)\nevaluation = dict(\n    interval=25,\n    metric=['PCK', 'NME', 'AUC', 'EPE'],\n    key_indicator='PCK',\n    gpu_collect=True,\n    res_folder='')\noptimizer = dict(\n    type='Adam',\n    lr=1e-5,\n)\n\noptimizer_config = dict(grad_clip=None)\n# learning policy\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[160, 180])\ntotal_epochs = 200\nlog_config = dict(\n    interval=50,\n    hooks=[\n        dict(type='TextLoggerHook'),\n        dict(type='TensorboardLoggerHook')\n    ])\n\nchannel_cfg = dict(\n    num_output_channels=1,\n    dataset_joints=1,\n    dataset_channel=[\n        [\n            0,\n        ],\n    ],\n    inference_channel=[\n        0,\n    ],\n    max_kpt_num=100)\n\n# model settings\nmodel = dict(\n    type='PoseAnythingModel',\n    pretrained='pretrained/swinv2_tiny_patch4_window16_256.pth',\n    encoder_config=dict(\n        type='SwinTransformerV2',\n        embed_dim=96,\n        depths=[2, 2, 6, 2],\n        num_heads=[3, 6, 12, 24],\n        window_size=16,\n        drop_path_rate=0.2,\n        img_size=256,\n        upsample=\"bilinear\"\n    ),\n    keypoint_head=dict(\n        type='PoseHead',\n        in_channels=768,\n        transformer=dict(\n            type='EncoderDecoder',\n            d_model=256,\n            nhead=8,\n            num_encoder_layers=3,\n            num_decoder_layers=3,\n            dim_feedforward=768,\n            dropout=0.1,\n            similarity_proj_dim=256,\n            dynamic_proj_dim=128,\n            activation=\"relu\",\n            normalize_before=False,\n            return_intermediate_dec=True),\n        share_kpt_branch=False,\n        num_decoder_layer=3,\n        with_heatmap_loss=True,\n        \n        heatmap_loss_weight=2.0,\n        support_order_dropout=-1,\n        positional_encoding=dict(\n            type='SinePositionalEncoding', num_feats=128, normalize=True)),\n    # training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(\n        flip_test=False,\n        post_process='default',\n        shift_heatmap=True,\n        modulate_kernel=11))\n\ndata_cfg = dict(\n    image_size=[256, 256],\n    heatmap_size=[64, 64],\n    num_output_channels=channel_cfg['num_output_channels'],\n    num_joints=channel_cfg['dataset_joints'],\n    dataset_channel=channel_cfg['dataset_channel'],\n    inference_channel=channel_cfg['inference_channel'])\n\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='TopDownGetRandomScaleRotation', rot_factor=15,\n        scale_factor=0.15),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',\n            'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',\n        ]),\n]\n\nvalid_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 
0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',\n            'flip_pairs', 'category_id',\n            'skeleton',\n        ]),\n]\n\ntest_pipeline = valid_pipeline\n\ndata_root = 'data/mp100'\ndata = dict(\n    samples_per_gpu=8,\n    workers_per_gpu=8,\n    train=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split4_train.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        pipeline=train_pipeline),\n    val=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split4_val.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        num_queries=15,\n        num_episodes=100,\n        pipeline=valid_pipeline),\n    test=dict(\n        type='TestPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split4_test.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        num_queries=15,\n        num_episodes=200,\n        pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],\n        pipeline=test_pipeline),\n)\nvis_backends = [\n    dict(type='LocalVisBackend'),\n    dict(type='TensorboardVisBackend'),\n]\nvisualizer = dict(\n    type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')\n\nshuffle_cfg = dict(interval=1)\n"
  },
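The split configs in this directory are identical except for the annotation split they point at (and, in the `graph_*` variants, one extra transformer key). A minimal sketch, not part of the repository, of how such mmcv-style configs are loaded and overridden from code; the override value is purely illustrative:

```python
# Minimal sketch: load one of the configs above with mmcv and
# override a field from code, the way entry scripts consume them.
from mmcv import Config

cfg = Config.fromfile('configs/5shots/base_split4_config.py')
print(cfg.model.type)           # PoseAnythingModel
print(cfg.data.train.ann_file)  # data/mp100/annotations/mp100_split4_train.json

# dotted-key override; illustrative value, not a recommendation
cfg.merge_from_dict({'data.samples_per_gpu': 4})
```

`merge_from_dict` with dotted keys is the same mechanism the `--cfg-options` flag in `demo.py` relies on.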
  {
    "path": "configs/5shots/base_split5_config.py",
    "content": "log_level = 'INFO'\nload_from = None\nresume_from = None\ndist_params = dict(backend='nccl')\nworkflow = [('train', 1)]\ncheckpoint_config = dict(interval=20)\nevaluation = dict(\n    interval=25,\n    metric=['PCK', 'NME', 'AUC', 'EPE'],\n    key_indicator='PCK',\n    gpu_collect=True,\n    res_folder='')\noptimizer = dict(\n    type='Adam',\n    lr=1e-5,\n)\n\noptimizer_config = dict(grad_clip=None)\n# learning policy\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[160, 180])\ntotal_epochs = 200\nlog_config = dict(\n    interval=50,\n    hooks=[\n        dict(type='TextLoggerHook'),\n        dict(type='TensorboardLoggerHook')\n    ])\n\nchannel_cfg = dict(\n    num_output_channels=1,\n    dataset_joints=1,\n    dataset_channel=[\n        [\n            0,\n        ],\n    ],\n    inference_channel=[\n        0,\n    ],\n    max_kpt_num=100)\n\n# model settings\nmodel = dict(\n    type='PoseAnythingModel',\n    pretrained='pretrained/swinv2_tiny_patch4_window16_256.pth',\n    encoder_config=dict(\n        type='SwinTransformerV2',\n        embed_dim=96,\n        depths=[2, 2, 6, 2],\n        num_heads=[3, 6, 12, 24],\n        window_size=16,\n        drop_path_rate=0.2,\n        img_size=256,\n        upsample=\"bilinear\"\n    ),\n    keypoint_head=dict(\n        type='PoseHead',\n        in_channels=768,\n        transformer=dict(\n            type='EncoderDecoder',\n            d_model=256,\n            nhead=8,\n            num_encoder_layers=3,\n            num_decoder_layers=3,\n            dim_feedforward=768,\n            dropout=0.1,\n            similarity_proj_dim=256,\n            dynamic_proj_dim=128,\n            activation=\"relu\",\n            normalize_before=False,\n            return_intermediate_dec=True),\n        share_kpt_branch=False,\n        num_decoder_layer=3,\n        with_heatmap_loss=True,\n        \n        heatmap_loss_weight=2.0,\n        support_order_dropout=-1,\n        positional_encoding=dict(\n            type='SinePositionalEncoding', num_feats=128, normalize=True)),\n    # training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(\n        flip_test=False,\n        post_process='default',\n        shift_heatmap=True,\n        modulate_kernel=11))\n\ndata_cfg = dict(\n    image_size=[256, 256],\n    heatmap_size=[64, 64],\n    num_output_channels=channel_cfg['num_output_channels'],\n    num_joints=channel_cfg['dataset_joints'],\n    dataset_channel=channel_cfg['dataset_channel'],\n    inference_channel=channel_cfg['inference_channel'])\n\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='TopDownGetRandomScaleRotation', rot_factor=15,\n        scale_factor=0.15),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',\n            'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',\n        ]),\n]\n\nvalid_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 
0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',\n            'flip_pairs', 'category_id',\n            'skeleton',\n        ]),\n]\n\ntest_pipeline = valid_pipeline\n\ndata_root = 'data/mp100'\ndata = dict(\n    samples_per_gpu=8,\n    workers_per_gpu=8,\n    train=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split5_train.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        pipeline=train_pipeline),\n    val=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split5_val.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        num_queries=15,\n        num_episodes=100,\n        pipeline=valid_pipeline),\n    test=dict(\n        type='TestPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split5_test.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        num_queries=15,\n        num_episodes=200,\n        pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],\n        pipeline=test_pipeline),\n)\nvis_backends = [\n    dict(type='LocalVisBackend'),\n    dict(type='TensorboardVisBackend'),\n]\nvisualizer = dict(\n    type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')\n\nshuffle_cfg = dict(interval=1)\n"
  },
  {
    "path": "configs/5shots/graph_split1_config.py",
    "content": "log_level = 'INFO'\nload_from = None\nresume_from = None\ndist_params = dict(backend='nccl')\nworkflow = [('train', 1)]\ncheckpoint_config = dict(interval=20)\nevaluation = dict(\n    interval=25,\n    metric=['PCK', 'NME', 'AUC', 'EPE'],\n    key_indicator='PCK',\n    gpu_collect=True,\n    res_folder='')\noptimizer = dict(\n    type='Adam',\n    lr=1e-5,\n)\n\noptimizer_config = dict(grad_clip=None)\n# learning policy\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[160, 180])\ntotal_epochs = 200\nlog_config = dict(\n    interval=50,\n    hooks=[\n        dict(type='TextLoggerHook'),\n        dict(type='TensorboardLoggerHook')\n    ])\n\nchannel_cfg = dict(\n    num_output_channels=1,\n    dataset_joints=1,\n    dataset_channel=[\n        [\n            0,\n        ],\n    ],\n    inference_channel=[\n        0,\n    ],\n    max_kpt_num=100)\n\n# model settings\nmodel = dict(\n    type='PoseAnythingModel',\n    pretrained='pretrained/swinv2_tiny_patch4_window16_256.pth',\n    encoder_config=dict(\n        type='SwinTransformerV2',\n        embed_dim=96,\n        depths=[2, 2, 6, 2],\n        num_heads=[3, 6, 12, 24],\n        window_size=16,\n        drop_path_rate=0.2,\n        img_size=256,\n        upsample=\"bilinear\"\n    ),\n    keypoint_head=dict(\n        type='PoseHead',\n        in_channels=768,\n        transformer=dict(\n            type='EncoderDecoder',\n            d_model=256,\n            nhead=8,\n            num_encoder_layers=3,\n            num_decoder_layers=3,\n            graph_decoder='pre',\n            dim_feedforward=768,\n            dropout=0.1,\n            similarity_proj_dim=256,\n            dynamic_proj_dim=128,\n            activation=\"relu\",\n            normalize_before=False,\n            return_intermediate_dec=True),\n        share_kpt_branch=False,\n        num_decoder_layer=3,\n        with_heatmap_loss=True,\n        \n        heatmap_loss_weight=2.0,\n        support_order_dropout=-1,\n        positional_encoding=dict(\n            type='SinePositionalEncoding', num_feats=128, normalize=True)),\n    # training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(\n        flip_test=False,\n        post_process='default',\n        shift_heatmap=True,\n        modulate_kernel=11))\n\ndata_cfg = dict(\n    image_size=[256, 256],\n    heatmap_size=[64, 64],\n    num_output_channels=channel_cfg['num_output_channels'],\n    num_joints=channel_cfg['dataset_joints'],\n    dataset_channel=channel_cfg['dataset_channel'],\n    inference_channel=channel_cfg['inference_channel'])\n\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='TopDownGetRandomScaleRotation', rot_factor=15,\n        scale_factor=0.15),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',\n            'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',\n        ]),\n]\n\nvalid_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 
0.406],\n        std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',\n            'flip_pairs', 'category_id',\n            'skeleton',\n        ]),\n]\n\ntest_pipeline = valid_pipeline\n\ndata_root = 'data/mp100'\ndata = dict(\n    samples_per_gpu=8,\n    workers_per_gpu=8,\n    train=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split1_train.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        pipeline=train_pipeline),\n    val=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split1_val.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        num_queries=15,\n        num_episodes=100,\n        pipeline=valid_pipeline),\n    test=dict(\n        type='TestPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split1_test.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        num_queries=15,\n        num_episodes=200,\n        pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],\n        pipeline=test_pipeline),\n)\nvis_backends = [\n    dict(type='LocalVisBackend'),\n    dict(type='TensorboardVisBackend'),\n]\nvisualizer = dict(\n    type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')\n\nshuffle_cfg = dict(interval=1)\n"
  },
  {
    "path": "configs/5shots/graph_split2_config.py",
    "content": "log_level = 'INFO'\nload_from = None\nresume_from = None\ndist_params = dict(backend='nccl')\nworkflow = [('train', 1)]\ncheckpoint_config = dict(interval=20)\nevaluation = dict(\n    interval=25,\n    metric=['PCK', 'NME', 'AUC', 'EPE'],\n    key_indicator='PCK',\n    gpu_collect=True,\n    res_folder='')\noptimizer = dict(\n    type='Adam',\n    lr=1e-5,\n)\n\noptimizer_config = dict(grad_clip=None)\n# learning policy\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[160, 180])\ntotal_epochs = 200\nlog_config = dict(\n    interval=50,\n    hooks=[\n        dict(type='TextLoggerHook'),\n        dict(type='TensorboardLoggerHook')\n    ])\n\nchannel_cfg = dict(\n    num_output_channels=1,\n    dataset_joints=1,\n    dataset_channel=[\n        [\n            0,\n        ],\n    ],\n    inference_channel=[\n        0,\n    ],\n    max_kpt_num=100)\n\n# model settings\nmodel = dict(\n    type='PoseAnythingModel',\n    pretrained='pretrained/swinv2_tiny_patch4_window16_256.pth',\n    encoder_config=dict(\n        type='SwinTransformerV2',\n        embed_dim=96,\n        depths=[2, 2, 6, 2],\n        num_heads=[3, 6, 12, 24],\n        window_size=16,\n        drop_path_rate=0.2,\n        img_size=256,\n        upsample=\"bilinear\"\n    ),\n    keypoint_head=dict(\n        type='PoseHead',\n        in_channels=768,\n        transformer=dict(\n            type='EncoderDecoder',\n            d_model=256,\n            nhead=8,\n            num_encoder_layers=3,\n            num_decoder_layers=3,\n            graph_decoder='pre',\n            dim_feedforward=768,\n            dropout=0.1,\n            similarity_proj_dim=256,\n            dynamic_proj_dim=128,\n            activation=\"relu\",\n            normalize_before=False,\n            return_intermediate_dec=True),\n        share_kpt_branch=False,\n        num_decoder_layer=3,\n        with_heatmap_loss=True,\n        \n        heatmap_loss_weight=2.0,\n        support_order_dropout=-1,\n        positional_encoding=dict(\n            type='SinePositionalEncoding', num_feats=128, normalize=True)),\n    # training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(\n        flip_test=False,\n        post_process='default',\n        shift_heatmap=True,\n        modulate_kernel=11))\n\ndata_cfg = dict(\n    image_size=[256, 256],\n    heatmap_size=[64, 64],\n    num_output_channels=channel_cfg['num_output_channels'],\n    num_joints=channel_cfg['dataset_joints'],\n    dataset_channel=channel_cfg['dataset_channel'],\n    inference_channel=channel_cfg['inference_channel'])\n\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='TopDownGetRandomScaleRotation', rot_factor=15,\n        scale_factor=0.15),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',\n            'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',\n        ]),\n]\n\nvalid_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 
0.406],\n        std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',\n            'flip_pairs', 'category_id',\n            'skeleton',\n        ]),\n]\n\ntest_pipeline = valid_pipeline\n\ndata_root = 'data/mp100'\ndata = dict(\n    samples_per_gpu=8,\n    workers_per_gpu=8,\n    train=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split2_train.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        pipeline=train_pipeline),\n    val=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split2_val.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        num_queries=15,\n        num_episodes=100,\n        pipeline=valid_pipeline),\n    test=dict(\n        type='TestPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split2_test.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        num_queries=15,\n        num_episodes=200,\n        pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],\n        pipeline=test_pipeline),\n)\nvis_backends = [\n    dict(type='LocalVisBackend'),\n    dict(type='TensorboardVisBackend'),\n]\nvisualizer = dict(\n    type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')\n\nshuffle_cfg = dict(interval=1)\n"
  },
  {
    "path": "configs/5shots/graph_split3_config.py",
    "content": "log_level = 'INFO'\nload_from = None\nresume_from = None\ndist_params = dict(backend='nccl')\nworkflow = [('train', 1)]\ncheckpoint_config = dict(interval=20)\nevaluation = dict(\n    interval=25,\n    metric=['PCK', 'NME', 'AUC', 'EPE'],\n    key_indicator='PCK',\n    gpu_collect=True,\n    res_folder='')\noptimizer = dict(\n    type='Adam',\n    lr=1e-5,\n)\n\noptimizer_config = dict(grad_clip=None)\n# learning policy\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[160, 180])\ntotal_epochs = 200\nlog_config = dict(\n    interval=50,\n    hooks=[\n        dict(type='TextLoggerHook'),\n        dict(type='TensorboardLoggerHook')\n    ])\n\nchannel_cfg = dict(\n    num_output_channels=1,\n    dataset_joints=1,\n    dataset_channel=[\n        [\n            0,\n        ],\n    ],\n    inference_channel=[\n        0,\n    ],\n    max_kpt_num=100)\n\n# model settings\nmodel = dict(\n    type='PoseAnythingModel',\n    pretrained='pretrained/swinv2_tiny_patch4_window16_256.pth',\n    encoder_config=dict(\n        type='SwinTransformerV2',\n        embed_dim=96,\n        depths=[2, 2, 6, 2],\n        num_heads=[3, 6, 12, 24],\n        window_size=16,\n        drop_path_rate=0.2,\n        img_size=256,\n        upsample=\"bilinear\"\n    ),\n    keypoint_head=dict(\n        type='PoseHead',\n        in_channels=768,\n        transformer=dict(\n            type='EncoderDecoder',\n            d_model=256,\n            nhead=8,\n            num_encoder_layers=3,\n            num_decoder_layers=3,\n            graph_decoder='pre',\n            dim_feedforward=768,\n            dropout=0.1,\n            similarity_proj_dim=256,\n            dynamic_proj_dim=128,\n            activation=\"relu\",\n            normalize_before=False,\n            return_intermediate_dec=True),\n        share_kpt_branch=False,\n        num_decoder_layer=3,\n        with_heatmap_loss=True,\n        \n        heatmap_loss_weight=2.0,\n        support_order_dropout=-1,\n        positional_encoding=dict(\n            type='SinePositionalEncoding', num_feats=128, normalize=True)),\n    # training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(\n        flip_test=False,\n        post_process='default',\n        shift_heatmap=True,\n        modulate_kernel=11))\n\ndata_cfg = dict(\n    image_size=[256, 256],\n    heatmap_size=[64, 64],\n    num_output_channels=channel_cfg['num_output_channels'],\n    num_joints=channel_cfg['dataset_joints'],\n    dataset_channel=channel_cfg['dataset_channel'],\n    inference_channel=channel_cfg['inference_channel'])\n\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='TopDownGetRandomScaleRotation', rot_factor=15,\n        scale_factor=0.15),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',\n            'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',\n        ]),\n]\n\nvalid_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 
0.406],\n        std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',\n            'flip_pairs', 'category_id',\n            'skeleton',\n        ]),\n]\n\ntest_pipeline = valid_pipeline\n\ndata_root = 'data/mp100'\ndata = dict(\n    samples_per_gpu=8,\n    workers_per_gpu=8,\n    train=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split3_train.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        pipeline=train_pipeline),\n    val=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split3_val.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        num_queries=15,\n        num_episodes=100,\n        pipeline=valid_pipeline),\n    test=dict(\n        type='TestPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split3_test.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        num_queries=15,\n        num_episodes=200,\n        pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],\n        pipeline=test_pipeline),\n)\nvis_backends = [\n    dict(type='LocalVisBackend'),\n    dict(type='TensorboardVisBackend'),\n]\nvisualizer = dict(\n    type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')\n\nshuffle_cfg = dict(interval=1)\n"
  },
  {
    "path": "configs/5shots/graph_split4_config.py",
    "content": "log_level = 'INFO'\nload_from = None\nresume_from = None\ndist_params = dict(backend='nccl')\nworkflow = [('train', 1)]\ncheckpoint_config = dict(interval=20)\nevaluation = dict(\n    interval=25,\n    metric=['PCK', 'NME', 'AUC', 'EPE'],\n    key_indicator='PCK',\n    gpu_collect=True,\n    res_folder='')\noptimizer = dict(\n    type='Adam',\n    lr=1e-5,\n)\n\noptimizer_config = dict(grad_clip=None)\n# learning policy\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[160, 180])\ntotal_epochs = 200\nlog_config = dict(\n    interval=50,\n    hooks=[\n        dict(type='TextLoggerHook'),\n        dict(type='TensorboardLoggerHook')\n    ])\n\nchannel_cfg = dict(\n    num_output_channels=1,\n    dataset_joints=1,\n    dataset_channel=[\n        [\n            0,\n        ],\n    ],\n    inference_channel=[\n        0,\n    ],\n    max_kpt_num=100)\n\n# model settings\nmodel = dict(\n    type='PoseAnythingModel',\n    pretrained='pretrained/swinv2_tiny_patch4_window16_256.pth',\n    encoder_config=dict(\n        type='SwinTransformerV2',\n        embed_dim=96,\n        depths=[2, 2, 6, 2],\n        num_heads=[3, 6, 12, 24],\n        window_size=16,\n        drop_path_rate=0.2,\n        img_size=256,\n        upsample=\"bilinear\"\n    ),\n    keypoint_head=dict(\n        type='PoseHead',\n        in_channels=768,\n        transformer=dict(\n            type='EncoderDecoder',\n            d_model=256,\n            nhead=8,\n            num_encoder_layers=3,\n            num_decoder_layers=3,\n            graph_decoder='pre',\n            dim_feedforward=768,\n            dropout=0.1,\n            similarity_proj_dim=256,\n            dynamic_proj_dim=128,\n            activation=\"relu\",\n            normalize_before=False,\n            return_intermediate_dec=True),\n        share_kpt_branch=False,\n        num_decoder_layer=3,\n        with_heatmap_loss=True,\n        \n        heatmap_loss_weight=2.0,\n        support_order_dropout=-1,\n        positional_encoding=dict(\n            type='SinePositionalEncoding', num_feats=128, normalize=True)),\n    # training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(\n        flip_test=False,\n        post_process='default',\n        shift_heatmap=True,\n        modulate_kernel=11))\n\ndata_cfg = dict(\n    image_size=[256, 256],\n    heatmap_size=[64, 64],\n    num_output_channels=channel_cfg['num_output_channels'],\n    num_joints=channel_cfg['dataset_joints'],\n    dataset_channel=channel_cfg['dataset_channel'],\n    inference_channel=channel_cfg['inference_channel'])\n\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='TopDownGetRandomScaleRotation', rot_factor=15,\n        scale_factor=0.15),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',\n            'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',\n        ]),\n]\n\nvalid_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 
0.406],\n        std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',\n            'flip_pairs', 'category_id',\n            'skeleton',\n        ]),\n]\n\ntest_pipeline = valid_pipeline\n\ndata_root = 'data/mp100'\ndata = dict(\n    samples_per_gpu=8,\n    workers_per_gpu=8,\n    train=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split4_train.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        pipeline=train_pipeline),\n    val=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split4_val.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        num_queries=15,\n        num_episodes=100,\n        pipeline=valid_pipeline),\n    test=dict(\n        type='TestPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split4_test.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        num_queries=15,\n        num_episodes=200,\n        pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],\n        pipeline=test_pipeline),\n)\nvis_backends = [\n    dict(type='LocalVisBackend'),\n    dict(type='TensorboardVisBackend'),\n]\nvisualizer = dict(\n    type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')\n\nshuffle_cfg = dict(interval=1)\n"
  },
  {
    "path": "configs/5shots/graph_split5_config.py",
    "content": "log_level = 'INFO'\nload_from = None\nresume_from = None\ndist_params = dict(backend='nccl')\nworkflow = [('train', 1)]\ncheckpoint_config = dict(interval=20)\nevaluation = dict(\n    interval=25,\n    metric=['PCK', 'NME', 'AUC', 'EPE'],\n    key_indicator='PCK',\n    gpu_collect=True,\n    res_folder='')\noptimizer = dict(\n    type='Adam',\n    lr=1e-5,\n)\n\noptimizer_config = dict(grad_clip=None)\n# learning policy\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[160, 180])\ntotal_epochs = 200\nlog_config = dict(\n    interval=50,\n    hooks=[\n        dict(type='TextLoggerHook'),\n        dict(type='TensorboardLoggerHook')\n    ])\n\nchannel_cfg = dict(\n    num_output_channels=1,\n    dataset_joints=1,\n    dataset_channel=[\n        [\n            0,\n        ],\n    ],\n    inference_channel=[\n        0,\n    ],\n    max_kpt_num=100)\n\n# model settings\nmodel = dict(\n    type='PoseAnythingModel',\n    pretrained='pretrained/swinv2_tiny_patch4_window16_256.pth',\n    encoder_config=dict(\n        type='SwinTransformerV2',\n        embed_dim=96,\n        depths=[2, 2, 6, 2],\n        num_heads=[3, 6, 12, 24],\n        window_size=16,\n        drop_path_rate=0.2,\n        img_size=256,\n        upsample=\"bilinear\"\n    ),\n    keypoint_head=dict(\n        type='PoseHead',\n        in_channels=768,\n        transformer=dict(\n            type='EncoderDecoder',\n            d_model=256,\n            nhead=8,\n            num_encoder_layers=3,\n            num_decoder_layers=3,\n            graph_decoder='pre',\n            dim_feedforward=768,\n            dropout=0.1,\n            similarity_proj_dim=256,\n            dynamic_proj_dim=128,\n            activation=\"relu\",\n            normalize_before=False,\n            return_intermediate_dec=True),\n        share_kpt_branch=False,\n        num_decoder_layer=3,\n        with_heatmap_loss=True,\n        \n        heatmap_loss_weight=2.0,\n        support_order_dropout=-1,\n        positional_encoding=dict(\n            type='SinePositionalEncoding', num_feats=128, normalize=True)),\n    # training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(\n        flip_test=False,\n        post_process='default',\n        shift_heatmap=True,\n        modulate_kernel=11))\n\ndata_cfg = dict(\n    image_size=[256, 256],\n    heatmap_size=[64, 64],\n    num_output_channels=channel_cfg['num_output_channels'],\n    num_joints=channel_cfg['dataset_joints'],\n    dataset_channel=channel_cfg['dataset_channel'],\n    inference_channel=channel_cfg['inference_channel'])\n\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='TopDownGetRandomScaleRotation', rot_factor=15,\n        scale_factor=0.15),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',\n            'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',\n        ]),\n]\n\nvalid_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 
0.406],\n        std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',\n            'flip_pairs', 'category_id',\n            'skeleton',\n        ]),\n]\n\ntest_pipeline = valid_pipeline\n\ndata_root = 'data/mp100'\ndata = dict(\n    samples_per_gpu=8,\n    workers_per_gpu=8,\n    train=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split5_train.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        pipeline=train_pipeline),\n    val=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split5_val.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        num_queries=15,\n        num_episodes=100,\n        pipeline=valid_pipeline),\n    test=dict(\n        type='TestPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split5_test.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=5,\n        num_queries=15,\n        num_episodes=200,\n        pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],\n        pipeline=test_pipeline),\n)\nvis_backends = [\n    dict(type='LocalVisBackend'),\n    dict(type='TensorboardVisBackend'),\n]\nvisualizer = dict(\n    type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')\n\nshuffle_cfg = dict(interval=1)\n"
  },
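Comparing the `base_*` and `graph_*` configs above, the model settings differ only in the transformer's `graph_decoder` key. A small sketch (assuming both files are on disk) that surfaces the diff:

```python
# Sketch: the only model-level difference between base and graph configs
# is the graph_decoder key on the transformer (absent vs. 'pre').
from mmcv import Config

base = Config.fromfile('configs/5shots/base_split4_config.py')
graph = Config.fromfile('configs/5shots/graph_split4_config.py')

print(base.model.keypoint_head.transformer.get('graph_decoder'))   # None
print(graph.model.keypoint_head.transformer.get('graph_decoder'))  # 'pre'
```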
  {
    "path": "configs/demo_b.py",
    "content": "log_level = 'INFO'\nload_from = None\nresume_from = None\ndist_params = dict(backend='nccl')\nworkflow = [('train', 1)]\ncheckpoint_config = dict(interval=20)\nevaluation = dict(\n    interval=25,\n    metric=['PCK', 'NME', 'AUC', 'EPE'],\n    key_indicator='PCK',\n    gpu_collect=True,\n    res_folder='')\noptimizer = dict(\n    type='Adam',\n    lr=1e-5,\n)\n\noptimizer_config = dict(grad_clip=None)\n# learning policy\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=1000,\n    warmup_ratio=0.001,\n    step=[160, 180])\ntotal_epochs = 200\nlog_config = dict(\n    interval=50,\n    hooks=[\n        dict(type='TextLoggerHook'),\n        dict(type='TensorboardLoggerHook')\n    ])\n\nchannel_cfg = dict(\n    num_output_channels=1,\n    dataset_joints=1,\n    dataset_channel=[\n        [\n            0,\n        ],\n    ],\n    inference_channel=[\n        0,\n    ],\n    max_kpt_num=100)\n\n# model settings\nmodel = dict(\n    type='PoseAnythingModel',\n    pretrained='swinv2_small',\n    encoder_config=dict(\n        type='SwinTransformerV2',\n        embed_dim=96,\n        depths=[2, 2, 18, 2],\n        num_heads=[3, 6, 12, 24],\n        window_size=16,\n        drop_path_rate=0.3,\n        img_size=256,\n        upsample=\"bilinear\"\n    ),\n    keypoint_head=dict(\n        type='PoseHead',\n        in_channels=768,\n        transformer=dict(\n            type='EncoderDecoder',\n            d_model=256,\n            nhead=8,\n            num_encoder_layers=3,\n            num_decoder_layers=3,\n            graph_decoder='pre',\n            dim_feedforward=768,\n            dropout=0.1,\n            similarity_proj_dim=256,\n            dynamic_proj_dim=128,\n            activation=\"relu\",\n            normalize_before=False,\n            return_intermediate_dec=True),\n        share_kpt_branch=False,\n        num_decoder_layer=3,\n        with_heatmap_loss=True,\n        \n        heatmap_loss_weight=2.0,\n        support_order_dropout=-1,\n        positional_encoding=dict(\n            type='SinePositionalEncoding', num_feats=128, normalize=True)),\n    # training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(\n        flip_test=False,\n        post_process='default',\n        shift_heatmap=True,\n        modulate_kernel=11))\n\ndata_cfg = dict(\n    image_size=[256, 256],\n    heatmap_size=[64, 64],\n    num_output_channels=channel_cfg['num_output_channels'],\n    num_joints=channel_cfg['dataset_joints'],\n    dataset_channel=channel_cfg['dataset_channel'],\n    inference_channel=channel_cfg['inference_channel'])\n\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='TopDownGetRandomScaleRotation', rot_factor=15,\n        scale_factor=0.15),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',\n            'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',\n        ]),\n]\n\nvalid_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='TopDownAffineFewShot'),\n    dict(type='ToTensor'),\n    dict(\n        type='NormalizeTensor',\n        mean=[0.485, 0.456, 0.406],\n        std=[0.229, 0.224, 
0.225]),\n    dict(type='TopDownGenerateTargetFewShot', sigma=1),\n    dict(\n        type='Collect',\n        keys=['img', 'target', 'target_weight'],\n        meta_keys=[\n            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',\n            'flip_pairs', 'category_id',\n            'skeleton',\n        ]),\n]\n\ntest_pipeline = valid_pipeline\n\ndata_root = 'data/mp100'\ndata = dict(\n    samples_per_gpu=8,\n    workers_per_gpu=8,\n    train=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split1_train.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        pipeline=train_pipeline),\n    val=dict(\n        type='TransformerPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split1_val.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        num_queries=15,\n        num_episodes=100,\n        pipeline=valid_pipeline),\n    test=dict(\n        type='TestPoseDataset',\n        ann_file=f'{data_root}/annotations/mp100_split1_test.json',\n        img_prefix=f'{data_root}/images/',\n        # img_prefix=f'{data_root}',\n        data_cfg=data_cfg,\n        valid_class_ids=None,\n        max_kpt_num=channel_cfg['max_kpt_num'],\n        num_shots=1,\n        num_queries=15,\n        num_episodes=200,\n        pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],\n        pipeline=test_pipeline),\n)\nvis_backends = [\n    dict(type='LocalVisBackend'),\n    dict(type='TensorboardVisBackend'),\n]\nvisualizer = dict(\n    type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')\n\nshuffle_cfg = dict(interval=1)\n"
  },
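`configs/demo_b.py` appears to be the config intended for the demo script below: it swaps in a deeper SwinV2 backbone (`depths=[2, 2, 18, 2]`, `drop_path_rate=0.3`, `pretrained='swinv2_small'`) and uses `num_shots=1`. A hedged sketch of building a model from it the way `demo.py` does; the checkpoint path is a placeholder, not a file shipped with the repo:

```python
# Sketch: construct the demo model from configs/demo_b.py, mirroring demo.py.
# 'ckpt/demo_b.pth' is a placeholder; substitute a real checkpoint path.
from mmcv import Config
from mmcv.runner import load_checkpoint
from mmpose.models import build_posenet

cfg = Config.fromfile('configs/demo_b.py')
model = build_posenet(cfg.model)
load_checkpoint(model, 'ckpt/demo_b.pth', map_location='cpu')  # placeholder
model.eval()
```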
  {
    "path": "demo.py",
    "content": "import argparse\nimport copy\nimport os\nimport pickle\nimport random\nimport cv2\nimport numpy as np\nimport torch\nfrom mmcv import Config, DictAction\nfrom mmcv.cnn import fuse_conv_bn\nfrom mmcv.runner import load_checkpoint\nfrom mmpose.core import wrap_fp16_model\nfrom mmpose.models import build_posenet\nfrom torchvision import transforms\nfrom models import *\nimport torchvision.transforms.functional as F\n\nfrom tools.visualization import plot_results\n\nCOLORS = [\n    [255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0],\n    [85, 255, 0], [0, 255, 0], [0, 255, 85], [0, 255, 170], [0, 255, 255],\n    [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], [170, 0, 255],\n    [255, 0, 255], [255, 0, 170], [255, 0, 85], [255, 0, 0]]\n\nclass Resize_Pad:\n    def __init__(self, w=256, h=256):\n        self.w = w\n        self.h = h\n\n    def __call__(self, image):\n        _, w_1, h_1 = image.shape\n        ratio_1 = w_1 / h_1\n        # check if the original and final aspect ratios are the same within a margin\n        if round(ratio_1, 2) != 1:\n            # padding to preserve aspect ratio\n            if ratio_1 > 1:  # Make the image higher\n                hp = int(w_1 - h_1)\n                hp = hp // 2\n                image = F.pad(image, (hp, 0, hp, 0), 0, \"constant\")\n                return F.resize(image, [self.h, self.w])\n            else:\n                wp = int(h_1 - w_1)\n                wp = wp // 2\n                image = F.pad(image, (0, wp, 0, wp), 0, \"constant\")\n                return F.resize(image, [self.h, self.w])\n        else:\n            return F.resize(image, [self.h, self.w])\n\n\ndef transform_keypoints_to_pad_and_resize(keypoints, image_size):\n    trans_keypoints = keypoints.clone()\n    h, w = image_size[:2]\n    ratio_1 = w / h\n    if ratio_1 > 1:\n        # width is bigger than height - pad height\n        hp = int(w - h)\n        hp = hp // 2\n        trans_keypoints[:, 1] = keypoints[:, 1] + hp\n        trans_keypoints *= (256. / w)\n    else:\n        # height is bigger than width - pad width\n        wp = int(image_size[1] - image_size[0])\n        wp = wp // 2\n        trans_keypoints[:, 0] = keypoints[:, 0] + wp\n        trans_keypoints *= (256. / h)\n    return trans_keypoints\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description='Pose Anything Demo')\n    parser.add_argument('--support', help='Image file')\n    parser.add_argument('--query', help='Image file')\n    parser.add_argument('--config', default=None, help='test config file path')\n    parser.add_argument('--checkpoint', default=None, help='checkpoint file')\n    parser.add_argument('--outdir', default='output', help='checkpoint file')\n\n    parser.add_argument(\n        '--fuse-conv-bn',\n        action='store_true',\n        help='Whether to fuse conv and bn, this will slightly increase'\n             'the inference speed')\n    parser.add_argument(\n        '--cfg-options',\n        nargs='+',\n        action=DictAction,\n        default={},\n        help='override some settings in the used config, the key-value pair '\n             'in xxx=yyy format will be merged into config file. 
For example, '\n             \"'--cfg-options model.backbone.depth=18 model.backbone.with_cp=True'\")\n    args = parser.parse_args()\n    return args\n\n\ndef merge_configs(cfg1, cfg2):\n    # Merge cfg2 into cfg1\n    # Overwrite cfg1 if repeated, ignore if value is None.\n    cfg1 = {} if cfg1 is None else cfg1.copy()\n    cfg2 = {} if cfg2 is None else cfg2\n    for k, v in cfg2.items():\n        if v:\n            cfg1[k] = v\n    return cfg1\n\n\ndef main():\n    random.seed(0)\n    np.random.seed(0)\n    torch.manual_seed(0)\n\n    args = parse_args()\n    cfg = Config.fromfile(args.config)\n\n    if args.cfg_options is not None:\n        cfg.merge_from_dict(args.cfg_options)\n    # set cudnn_benchmark\n    if cfg.get('cudnn_benchmark', False):\n        torch.backends.cudnn.benchmark = True\n    cfg.data.test.test_mode = True\n\n    os.makedirs(args.outdir, exist_ok=True)\n\n    # Load data\n    support_img = cv2.imread(args.support)\n    query_img = cv2.imread(args.query)\n    if support_img is None or query_img is None:\n        raise ValueError('Fail to read images')\n\n    preprocess = transforms.Compose([\n        transforms.ToTensor(),\n        Resize_Pad(cfg.model.encoder_config.img_size, cfg.model.encoder_config.img_size)])\n\n    # frame = copy.deepcopy(support_img)\n    padded_support_img = preprocess(support_img).cpu().numpy().transpose(1, 2, 0) * 255\n    frame = copy.deepcopy(padded_support_img.astype(np.uint8).copy())\n    kp_src = []\n    skeleton = []\n    count = 0\n    prev_pt = None\n    prev_pt_idx = None\n    color_idx = 0\n\n    def selectKP(event, x, y, flags, param):\n        nonlocal kp_src, frame\n        # if we are in points selection mode, the mouse was clicked,\n        # list of  points with the (x, y) location of the click\n        # and draw the circle\n\n        if event == cv2.EVENT_LBUTTONDOWN:\n            kp_src.append((x, y))\n            cv2.circle(frame, (x, y), 2, (0, 0, 255), 1)\n            cv2.imshow(\"Source\", frame)\n\n        if event == cv2.EVENT_RBUTTONDOWN:\n            kp_src = []\n            frame = copy.deepcopy(support_img)\n            cv2.imshow(\"Source\", frame)\n\n    def draw_line(event, x, y, flags, param):\n        nonlocal skeleton, kp_src, frame, count, prev_pt, prev_pt_idx, marked_frame, color_idx\n        if event == cv2.EVENT_LBUTTONDOWN:\n            closest_point = min(kp_src, key=lambda p: (p[0] - x) ** 2 + (p[1] - y) ** 2)\n            closest_point_index = kp_src.index(closest_point)\n            if color_idx < len(COLORS):\n                c = COLORS[color_idx]\n            else:\n                c = random.choices(range(256), k=3)\n            color = color_idx\n            cv2.circle(frame, closest_point, 2, c, 1)\n            if count == 0:\n                prev_pt = closest_point\n                prev_pt_idx = closest_point_index\n                count = count + 1\n                cv2.imshow(\"Source\", frame)\n            else:\n                cv2.line(frame, prev_pt, closest_point, c, 2)\n                cv2.imshow(\"Source\", frame)\n                count = 0\n                skeleton.append((prev_pt_idx, closest_point_index))\n                color_idx = color_idx + 1\n        elif event == cv2.EVENT_RBUTTONDOWN:\n            frame = copy.deepcopy(marked_frame)\n            cv2.imshow(\"Source\", frame)\n            count = 0\n            color_idx = 0\n            skeleton = []\n            prev_pt = None\n\n    cv2.namedWindow(\"Source\", cv2.WINDOW_NORMAL)\n    cv2.resizeWindow('Source', 800, 
600)\n    cv2.setMouseCallback(\"Source\", selectKP)\n    cv2.imshow(\"Source\", frame)\n\n    # keep looping until points have been selected\n    print('Press any key when finished marking the points!! ')\n    while True:\n        if cv2.waitKey(1) > 0:\n            break\n\n    marked_frame = copy.deepcopy(frame)\n    cv2.setMouseCallback(\"Source\", draw_line)\n    print('Press any key when finished creating skeleton!!')\n    while True:\n        if cv2.waitKey(1) > 0:\n            break\n\n    cv2.destroyAllWindows()\n    kp_src = torch.tensor(kp_src).float()\n    preprocess = transforms.Compose([\n        transforms.ToTensor(),\n        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),\n        Resize_Pad(cfg.model.encoder_config.img_size, cfg.model.encoder_config.img_size)])\n\n    if len(skeleton) == 0:\n        skeleton = [(0, 0)]\n\n    support_img = preprocess(support_img).flip(0)[None]\n    query_img = preprocess(query_img).flip(0)[None]\n    # Create heatmap from keypoints\n    genHeatMap = TopDownGenerateTargetFewShot()\n    data_cfg = cfg.data_cfg\n    data_cfg['image_size'] = np.array([cfg.model.encoder_config.img_size, cfg.model.encoder_config.img_size])\n    data_cfg['joint_weights'] = None\n    data_cfg['use_different_joint_weights'] = False\n    kp_src_3d = torch.concatenate((kp_src, torch.zeros(kp_src.shape[0], 1)), dim=-1)\n    kp_src_3d_weight = torch.concatenate((torch.ones_like(kp_src), torch.zeros(kp_src.shape[0], 1)), dim=-1)\n    target_s, target_weight_s = genHeatMap._msra_generate_target(data_cfg, kp_src_3d, kp_src_3d_weight, sigma=1)\n    target_s = torch.tensor(target_s).float()[None]\n    target_weight_s = torch.tensor(target_weight_s).float()[None]\n\n    data = {\n        'img_s': [support_img],\n        'img_q': query_img,\n        'target_s': [target_s],\n        'target_weight_s': [target_weight_s],\n        'target_q': None,\n        'target_weight_q': None,\n        'return_loss': False,\n        'img_metas': [{'sample_skeleton': [skeleton],\n                       'query_skeleton': skeleton,\n                       'sample_joints_3d': [kp_src_3d],\n                       'query_joints_3d': kp_src_3d,\n                       'sample_center': [kp_src.mean(dim=0)],\n                       'query_center': kp_src.mean(dim=0),\n                       'sample_scale': [kp_src.max(dim=0)[0] - kp_src.min(dim=0)[0]],\n                       'query_scale': kp_src.max(dim=0)[0] - kp_src.min(dim=0)[0],\n                       'sample_rotation': [0],\n                       'query_rotation': 0,\n                       'sample_bbox_score': [1],\n                       'query_bbox_score': 1,\n                       'query_image_file': '',\n                       'sample_image_file': [''],\n                       }]\n    }\n\n    # Load model\n    model = build_posenet(cfg.model)\n    fp16_cfg = cfg.get('fp16', None)\n    if fp16_cfg is not None:\n        wrap_fp16_model(model)\n    load_checkpoint(model, args.checkpoint, map_location='cpu')\n    if args.fuse_conv_bn:\n        model = fuse_conv_bn(model)\n    model.eval()\n\n    with torch.no_grad():\n        outputs = model(**data)\n\n    # visualize results\n    vis_s_weight = target_weight_s[0]\n    vis_q_weight = target_weight_s[0]\n    vis_s_image = support_img[0].detach().cpu().numpy().transpose(1, 2, 0)\n    vis_q_image = query_img[0].detach().cpu().numpy().transpose(1, 2, 0)\n    support_kp = kp_src_3d\n\n    plot_results(vis_s_image,\n                 vis_q_image,\n                 
support_kp,\n                 vis_s_weight,\n                 None,\n                 vis_q_weight,\n                 skeleton,\n                 None,\n                 torch.tensor(outputs['points']).squeeze(0),\n                 out_dir=args.outdir)\n\n\nif __name__ == '__main__':\n    main()\n"
  },
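The `Resize_Pad` transform in `demo.py` pads the shorter image side symmetrically and then resizes, so the support image shown for clicking and the tensor fed to the model share one square coordinate frame. A self-contained re-expression of that logic (function name and shapes are illustrative only):

```python
# Sketch: pad-to-square-then-resize on a dummy 3x100x200 (C, H, W) tensor.
# The wide image is padded top/bottom to a square, then resized to 256x256.
import torch
import torchvision.transforms.functional as F

def resize_pad(image, size=256):
    _, h, w = image.shape
    if h < w:        # wider than tall: pad height (top/bottom)
        p = (w - h) // 2
        image = F.pad(image, (0, p, 0, p), 0, "constant")
    elif w < h:      # taller than wide: pad width (left/right)
        p = (h - w) // 2
        image = F.pad(image, (p, 0, p, 0), 0, "constant")
    return F.resize(image, [size, size])

out = resize_pad(torch.rand(3, 100, 200))
print(out.shape)  # torch.Size([3, 256, 256])
```

The same reasoning underlies `transform_keypoints_to_pad_and_resize`, which applies the identical shift-and-scale to keypoint coordinates instead of pixels.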
  {
    "path": "docker/Dockerfile",
    "content": "ARG PYTORCH=\"2.0.1\"\nARG CUDA=\"11.7\"\nARG CUDNN=\"8\"\n\nFROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel\n\nENV TORCH_CUDA_ARCH_LIST=\"6.0 6.1 7.0+PTX\"\nENV TORCH_NVCC_FLAGS=\"-Xfatbin -compress-all\"\nENV CMAKE_PREFIX_PATH=\"$(dirname $(which conda))/../\"\nENV TZ=Asia/Kolkata DEBIAN_FRONTEND=noninteractive\n# To fix GPG key error when running apt-get update\nRUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub\nRUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub\n\nRUN apt-get update && apt-get install -y git ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6 libgl1-mesa-glx\\\n    && apt-get clean \\\n    && rm -rf /var/lib/apt/lists/*\n\n# Install xtcocotools\nRUN pip install cython\nRUN pip install xtcocotools\n# Install MMEngine and MMCV\nRUN pip install openmim\nRUN mim install mmengine\nRUN mim install \"mmpose==0.28.1\"\nRUN mim install \"mmcv-full==1.5.3\"\nRUN pip install -U torchmetrics timm\nRUN pip install numpy scipy --upgrade\nRUN pip install future tensorboard\n\nWORKDIR PoseAnything\n\nCOPY models PoseAnything/models\nCOPY configs PoseAnything/configs\nCOPY pretrained PoseAnything/pretrained\nCOPY requirements.txt PoseAnything/\nCOPY tools PoseAnything/tools\nCOPY setup.cfg PoseAnything/\nCOPY setup.py PoseAnything/\nCOPY test.py PoseAnything/\nCOPY train.py PoseAnything/\nCOPY README.md PoseAnything/\n\nRUN mkdir -p PoseAnything/data/mp100\nWORKDIR PoseAnything\n\n# Install MMPose\nRUN conda clean --all\nENV FORCE_CUDA=\"1\"\nRUN python setup.py develop"
  },
  {
    "path": "models/VERSION",
    "content": "0.2.0\n"
  },
  {
    "path": "models/__init__.py",
    "content": "from .core import *  # noqa\nfrom .datasets import *  # noqa\nfrom .models import *  # noqa\n"
  },
  {
    "path": "models/apis/__init__.py",
    "content": "from .train import train_model\n\n__all__ = [\n    'train_model'\n]\n"
  },
  {
    "path": "models/apis/train.py",
    "content": "import os\n\nimport torch\nfrom models.core.custom_hooks.shuffle_hooks import ShufflePairedSamplesHook\nfrom mmcv.parallel import MMDataParallel, MMDistributedDataParallel\nfrom mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner, OptimizerHook,\n                         build_optimizer)\nfrom mmpose.core import DistEvalHook, EvalHook, Fp16OptimizerHook\nfrom mmpose.datasets import build_dataloader\nfrom mmpose.utils import get_root_logger\n\n\ndef train_model(model,\n                dataset,\n                val_dataset,\n                cfg,\n                distributed=False,\n                validate=False,\n                timestamp=None,\n                meta=None):\n    \"\"\"Train model entry function.\n\n    Args:\n        model (nn.Module): The model to be trained.\n        dataset (Dataset): Train dataset.\n        cfg (dict): The config dict for training.\n        distributed (bool): Whether to use distributed training.\n            Default: False.\n        validate (bool): Whether to do evaluation. Default: False.\n        timestamp (str | None): Local time for runner. Default: None.\n        meta (dict | None): Meta dict to record some important information.\n            Default: None\n    \"\"\"\n    logger = get_root_logger(cfg.log_level)\n\n    # prepare data loaders\n    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]\n    dataloader_setting = dict(\n        samples_per_gpu=cfg.data.get('samples_per_gpu', {}),\n        workers_per_gpu=cfg.data.get('workers_per_gpu', {}),\n        # cfg.gpus will be ignored if distributed\n        num_gpus=len(cfg.gpu_ids),\n        dist=distributed,\n        seed=cfg.seed,\n        pin_memory=False,\n    )\n    dataloader_setting = dict(dataloader_setting,\n                              **cfg.data.get('train_dataloader', {}))\n\n    data_loaders = [\n        build_dataloader(ds, **dataloader_setting) for ds in dataset\n    ]\n\n    # put model on gpus\n    if distributed:\n        find_unused_parameters = cfg.get('find_unused_parameters',\n                                         False)  # NOTE: True has been modified to False for faster training.\n        # Sets the `find_unused_parameters` parameter in\n        # torch.nn.parallel.DistributedDataParallel\n        model = MMDistributedDataParallel(\n            model.cuda(),\n            device_ids=[torch.cuda.current_device()],\n            broadcast_buffers=False,\n            find_unused_parameters=find_unused_parameters)\n    else:\n        model = MMDataParallel(\n            model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)\n\n    # build runner\n    optimizer = build_optimizer(model, cfg.optimizer)\n    runner = EpochBasedRunner(\n        model,\n        optimizer=optimizer,\n        work_dir=cfg.work_dir,\n        logger=logger,\n        meta=meta)\n    # an ugly workaround to make .log and .log.json filenames the same\n    runner.timestamp = timestamp\n\n    # fp16 setting\n    fp16_cfg = cfg.get('fp16', None)\n    if fp16_cfg is not None:\n        optimizer_config = Fp16OptimizerHook(\n            **cfg.optimizer_config, **fp16_cfg, distributed=distributed)\n    elif distributed and 'type' not in cfg.optimizer_config:\n        optimizer_config = OptimizerHook(**cfg.optimizer_config)\n    else:\n        optimizer_config = cfg.optimizer_config\n\n    # register hooks\n    runner.register_training_hooks(cfg.lr_config, optimizer_config,\n                                   cfg.checkpoint_config, cfg.log_config,\n                         
          cfg.get('momentum_config', None))\n    if distributed:\n        runner.register_hook(DistSamplerSeedHook())\n\n    shuffle_cfg = cfg.get('shuffle_cfg', None)\n    if shuffle_cfg is not None:\n        for data_loader in data_loaders:\n            runner.register_hook(ShufflePairedSamplesHook(data_loader, **shuffle_cfg))\n\n    # register eval hooks\n    if validate:\n        eval_cfg = cfg.get('evaluation', {})\n        eval_cfg['res_folder'] = os.path.join(cfg.work_dir, eval_cfg['res_folder'])\n        dataloader_setting = dict(\n            # samples_per_gpu=cfg.data.get('samples_per_gpu', {}),\n            samples_per_gpu=1,\n            workers_per_gpu=cfg.data.get('workers_per_gpu', {}),\n            # cfg.gpus will be ignored if distributed\n            num_gpus=len(cfg.gpu_ids),\n            dist=distributed,\n            shuffle=False,\n            pin_memory=False,\n        )\n        dataloader_setting = dict(dataloader_setting,\n                                  **cfg.data.get('val_dataloader', {}))\n        val_dataloader = build_dataloader(val_dataset, **dataloader_setting)\n        eval_hook = DistEvalHook if distributed else EvalHook\n        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))\n\n    if cfg.resume_from:\n        runner.resume(cfg.resume_from)\n    elif cfg.load_from:\n        runner.load_checkpoint(cfg.load_from)\n    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)\n"
  },
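  {
    "path": "examples/sketch_dataloader_setting_merge.py",
    "content": "\"\"\"Illustrative sketch (not part of the training pipeline) of the dict-merge\nidiom train.py uses for dataloader settings: defaults are built first, then\nper-split overrides from cfg.data win via dict(base, **override). The numbers\nbelow are made up.\"\"\"\n\nbase = dict(samples_per_gpu=16, workers_per_gpu=2, dist=False, seed=0)\n# stands in for cfg.data.get('train_dataloader', {})\noverride = dict(samples_per_gpu=8, shuffle=True)\n\nmerged = dict(base, **override)  # keys in override win on collision\nassert merged['samples_per_gpu'] == 8\nassert merged['workers_per_gpu'] == 2\nprint(merged)\n"
  },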
  {
    "path": "models/core/__init__.py",
    "content": "\n"
  },
  {
    "path": "models/core/custom_hooks/shuffle_hooks.py",
    "content": "from mmcv.runner import Hook\nfrom mmpose.utils import get_root_logger\nfrom torch.utils.data import DataLoader\n\n\nclass ShufflePairedSamplesHook(Hook):\n    \"\"\"Non-Distributed ShufflePairedSamples.\n    After each training epoch, run FewShotKeypointDataset.random_paired_samples()\n    \"\"\"\n\n    def __init__(self,\n                 dataloader,\n                 interval=1):\n        if not isinstance(dataloader, DataLoader):\n            raise TypeError(f'dataloader must be a pytorch DataLoader, '\n                            f'but got {type(dataloader)}')\n\n        self.dataloader = dataloader\n        self.interval = interval\n        self.logger = get_root_logger()\n\n    def after_train_epoch(self, runner):\n        \"\"\"Called after every training epoch to evaluate the results.\"\"\"\n        if not self.every_n_epochs(runner, self.interval):\n            return\n        # self.logger.info(\"Run random_paired_samples()\")\n        # self.logger.info(f\"Before: {self.dataloader.dataset.paired_samples[0]}\")\n        self.dataloader.dataset.random_paired_samples()\n        # self.logger.info(f\"After: {self.dataloader.dataset.paired_samples[0]}\")\n"
  },
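  {
    "path": "examples/sketch_shuffle_hook_interval.py",
    "content": "\"\"\"Illustrative sketch of the epoch gating ShufflePairedSamplesHook relies on:\nmmcv's Hook.every_n_epochs(runner, n) is true when (epoch + 1) % n == 0.\nFakeRunner is a stand-in for the mmcv runner, not a real API.\"\"\"\n\n\nclass FakeRunner:\n\n    def __init__(self, epoch):\n        self.epoch = epoch  # mmcv runners count epochs from 0\n\n\ndef every_n_epochs(runner, n):\n    # mirrors mmcv.runner.Hook.every_n_epochs\n    return (runner.epoch + 1) % n == 0 if n > 0 else False\n\n\nfor epoch in range(6):\n    if every_n_epochs(FakeRunner(epoch), 2):  # interval=2\n        print(f'after epoch {epoch}: would re-draw paired samples')\n"
  },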
  {
    "path": "models/datasets/__init__.py",
    "content": "from .builder import *  # noqa\nfrom .datasets import *  # noqa\nfrom .pipelines import *  # noqa\n"
  },
  {
    "path": "models/datasets/builder.py",
    "content": "from mmcv.utils import build_from_cfg\nfrom mmpose.datasets.builder import DATASETS\nfrom mmpose.datasets.dataset_wrappers import RepeatDataset\nfrom torch.utils.data.dataset import ConcatDataset\n\n\ndef _concat_cfg(cfg):\n    replace = ['ann_file', 'img_prefix']\n    channels = ['num_joints', 'dataset_channel']\n    concat_cfg = []\n    for i in range(len(cfg['type'])):\n        cfg_tmp = cfg.deepcopy()\n        cfg_tmp['type'] = cfg['type'][i]\n        for item in replace:\n            assert item in cfg_tmp\n            assert len(cfg['type']) == len(cfg[item]), (cfg[item])\n            cfg_tmp[item] = cfg[item][i]\n        for item in channels:\n            assert item in cfg_tmp['data_cfg']\n            assert len(cfg['type']) == len(cfg['data_cfg'][item])\n            cfg_tmp['data_cfg'][item] = cfg['data_cfg'][item][i]\n        concat_cfg.append(cfg_tmp)\n    return concat_cfg\n\n\ndef _check_vaild(cfg):\n    replace = ['num_joints', 'dataset_channel']\n    if isinstance(cfg['data_cfg'][replace[0]], (list, tuple)):\n        for item in replace:\n            cfg['data_cfg'][item] = cfg['data_cfg'][item][0]\n    return cfg\n\n\ndef build_dataset(cfg, default_args=None):\n    \"\"\"Build a dataset from config dict.\n\n    Args:\n        cfg (dict): Config dict. It should at least contain the key \"type\".\n        default_args (dict, optional): Default initialization arguments.\n            Default: None.\n\n    Returns:\n        Dataset: The constructed dataset.\n    \"\"\"\n    if isinstance(cfg['type'], (list, tuple)):  # In training, type=TransformerPoseDataset\n        dataset = ConcatDataset(\n            [build_dataset(c, default_args) for c in _concat_cfg(cfg)])\n    elif cfg['type'] == 'RepeatDataset':\n        dataset = RepeatDataset(\n            build_dataset(cfg['dataset'], default_args), cfg['times'])\n    else:\n        cfg = _check_vaild(cfg)\n        dataset = build_from_cfg(cfg, DATASETS, default_args)\n    return dataset\n"
  },
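  {
    "path": "examples/sketch_concat_cfg.py",
    "content": "\"\"\"Standalone sketch of how build_dataset expands a config whose 'type' is a\nlist (the _concat_cfg path), using plain dicts and copy.deepcopy in place of\nmmcv's ConfigDict.deepcopy(). Dataset names and files are made up.\"\"\"\n\nimport copy\n\ncfg = dict(\n    type=['DatasetA', 'DatasetB'],\n    ann_file=['a.json', 'b.json'],\n    img_prefix=['imgs_a/', 'imgs_b/'],\n    data_cfg=dict(num_joints=[1, 1], dataset_channel=[[0], [0]]),\n)\n\nper_dataset = []\nfor i in range(len(cfg['type'])):\n    sub = copy.deepcopy(cfg)\n    sub['type'] = cfg['type'][i]\n    for key in ('ann_file', 'img_prefix'):\n        sub[key] = cfg[key][i]  # pick the i-th entry of each paired list\n    for key in ('num_joints', 'dataset_channel'):\n        sub['data_cfg'][key] = cfg['data_cfg'][key][i]\n    per_dataset.append(sub)\n\nprint(per_dataset[0]['type'], per_dataset[0]['ann_file'])  # DatasetA a.json\nprint(per_dataset[1]['type'], per_dataset[1]['ann_file'])  # DatasetB b.json\n"
  },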
  {
    "path": "models/datasets/datasets/__init__.py",
    "content": "from .mp100 import (FewShotKeypointDataset, FewShotBaseDataset,\n                    TransformerBaseDataset, TransformerPoseDataset)\n\n__all__ = ['FewShotBaseDataset', 'FewShotKeypointDataset',\n           'TransformerBaseDataset', 'TransformerPoseDataset']\n"
  },
  {
    "path": "models/datasets/datasets/mp100/__init__.py",
    "content": "from .fewshot_base_dataset import FewShotBaseDataset\nfrom .fewshot_dataset import FewShotKeypointDataset\nfrom .test_base_dataset import TestBaseDataset\nfrom .test_dataset import TestPoseDataset\nfrom .transformer_base_dataset import TransformerBaseDataset\nfrom .transformer_dataset import TransformerPoseDataset\n\n__all__ = [\n    'FewShotKeypointDataset', 'FewShotBaseDataset',\n    'TransformerPoseDataset', 'TransformerBaseDataset',\n    'TestBaseDataset', 'TestPoseDataset'\n]\n"
  },
  {
    "path": "models/datasets/datasets/mp100/fewshot_base_dataset.py",
    "content": "import copy\nfrom abc import ABCMeta, abstractmethod\n\nimport json_tricks as json\nimport numpy as np\nfrom mmcv.parallel import DataContainer as DC\nfrom mmpose.core.evaluation.top_down_eval import (keypoint_pck_accuracy)\nfrom mmpose.datasets import DATASETS\nfrom mmpose.datasets.pipelines import Compose\nfrom torch.utils.data import Dataset\n\n\n@DATASETS.register_module()\nclass FewShotBaseDataset(Dataset, metaclass=ABCMeta):\n\n    def __init__(self,\n                 ann_file,\n                 img_prefix,\n                 data_cfg,\n                 pipeline,\n                 test_mode=False):\n        self.image_info = {}\n        self.ann_info = {}\n\n        self.annotations_path = ann_file\n        if not img_prefix.endswith('/'):\n            img_prefix = img_prefix + '/'\n        self.img_prefix = img_prefix\n        self.pipeline = pipeline\n        self.test_mode = test_mode\n\n        self.ann_info['image_size'] = np.array(data_cfg['image_size'])\n        self.ann_info['heatmap_size'] = np.array(data_cfg['heatmap_size'])\n        self.ann_info['num_joints'] = data_cfg['num_joints']\n\n        self.ann_info['flip_pairs'] = None\n\n        self.ann_info['inference_channel'] = data_cfg['inference_channel']\n        self.ann_info['num_output_channels'] = data_cfg['num_output_channels']\n        self.ann_info['dataset_channel'] = data_cfg['dataset_channel']\n\n        self.db = []\n        self.num_shots = 1\n        self.paired_samples = []\n        self.pipeline = Compose(self.pipeline)\n\n    @abstractmethod\n    def _get_db(self):\n        \"\"\"Load dataset.\"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    def _select_kpt(self, obj, kpt_id):\n        \"\"\"Select kpt.\"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    def evaluate(self, cfg, preds, output_dir, *args, **kwargs):\n        \"\"\"Evaluate keypoint results.\"\"\"\n        raise NotImplementedError\n\n    @staticmethod\n    def _write_keypoint_results(keypoints, res_file):\n        \"\"\"Write results into a json file.\"\"\"\n\n        with open(res_file, 'w') as f:\n            json.dump(keypoints, f, sort_keys=True, indent=4)\n\n    def _report_metric(self,\n                       res_file,\n                       metrics,\n                       pck_thr=0.2,\n                       pckh_thr=0.7,\n                       auc_nor=30):\n        \"\"\"Keypoint evaluation.\n\n        Args:\n            res_file (str): Json file stored prediction results.\n            metrics (str | list[str]): Metric to be performed.\n                Options: 'PCK', 'PCKh', 'AUC', 'EPE'.\n            pck_thr (float): PCK threshold, default as 0.2.\n            pckh_thr (float): PCKh threshold, default as 0.7.\n            auc_nor (float): AUC normalization factor, default as 30 pixel.\n\n        Returns:\n            List: Evaluation results for evaluation metric.\n        \"\"\"\n        info_str = []\n\n        with open(res_file, 'r') as fin:\n            preds = json.load(fin)\n        assert len(preds) == len(self.paired_samples)\n\n        outputs = []\n        gts = []\n        masks = []\n        threshold_bbox = []\n        threshold_head_box = []\n\n        for pred, pair in zip(preds, self.paired_samples):\n            item = self.db[pair[-1]]\n            outputs.append(np.array(pred['keypoints'])[:, :-1])\n            gts.append(np.array(item['joints_3d'])[:, :-1])\n\n            mask_query = ((np.array(item['joints_3d_visible'])[:, 0]) > 0)\n            
mask_sample = ((np.array(self.db[pair[0]]['joints_3d_visible'])[:, 0]) > 0)\n            for id_s in pair[:-1]:\n                mask_sample = np.bitwise_and(mask_sample, ((np.array(self.db[id_s]['joints_3d_visible'])[:, 0]) > 0))\n            masks.append(np.bitwise_and(mask_query, mask_sample))\n\n            if 'PCK' in metrics:\n                bbox = np.array(item['bbox'])\n                bbox_thr = np.max(bbox[2:])\n                threshold_bbox.append(np.array([bbox_thr, bbox_thr]))\n            if 'PCKh' in metrics:\n                head_box_thr = item['head_size']\n                threshold_head_box.append(\n                    np.array([head_box_thr, head_box_thr]))\n\n        if 'PCK' in metrics:\n            pck_avg = []\n            for (output, gt, mask, thr_bbox) in zip(outputs, gts, masks, threshold_bbox):\n                _, pck, _ = keypoint_pck_accuracy(np.expand_dims(output, 0), np.expand_dims(gt, 0),\n                                                  np.expand_dims(mask, 0), pck_thr, np.expand_dims(thr_bbox, 0))\n                pck_avg.append(pck)\n            info_str.append(('PCK', np.mean(pck_avg)))\n\n        return info_str\n\n    def _merge_obj(self, Xs_list, Xq, idx):\n        \"\"\" merge Xs_list and Xq.\n\n        :param Xs_list: N-shot samples X\n        :param Xq: query X\n        :param idx: id of paired_samples\n        :return: Xall\n        \"\"\"\n        Xall = dict()\n        Xall['img_s'] = [Xs['img'] for Xs in Xs_list]\n        Xall['target_s'] = [Xs['target'] for Xs in Xs_list]\n        Xall['target_weight_s'] = [Xs['target_weight'] for Xs in Xs_list]\n        xs_img_metas = [Xs['img_metas'].data for Xs in Xs_list]\n\n        Xall['img_q'] = Xq['img']\n        Xall['target_q'] = Xq['target']\n        Xall['target_weight_q'] = Xq['target_weight']\n        xq_img_metas = Xq['img_metas'].data\n\n        img_metas = dict()\n        for key in xq_img_metas.keys():\n            img_metas['sample_' + key] = [xs_img_meta[key] for xs_img_meta in xs_img_metas]\n            img_metas['query_' + key] = xq_img_metas[key]\n        img_metas['bbox_id'] = idx\n\n        Xall['img_metas'] = DC(img_metas, cpu_only=True)\n\n        return Xall\n\n    def __len__(self):\n        \"\"\"Get the size of the dataset.\"\"\"\n        return len(self.paired_samples)\n\n    def __getitem__(self, idx):\n        \"\"\"Get the sample given index.\"\"\"\n\n        pair_ids = self.paired_samples[idx]\n        assert len(pair_ids) == self.num_shots + 1\n        sample_id_list = pair_ids[:self.num_shots]\n        query_id = pair_ids[-1]\n\n        sample_obj_list = []\n        for sample_id in sample_id_list:\n            sample_obj = copy.deepcopy(self.db[sample_id])\n            sample_obj['ann_info'] = copy.deepcopy(self.ann_info)\n            sample_obj_list.append(sample_obj)\n\n        query_obj = copy.deepcopy(self.db[query_id])\n        query_obj['ann_info'] = copy.deepcopy(self.ann_info)\n\n        if not self.test_mode:\n            # randomly select \"one\" keypoint\n            sample_valid = (sample_obj_list[0]['joints_3d_visible'][:, 0] > 0)\n            for sample_obj in sample_obj_list:\n                sample_valid = sample_valid & (sample_obj['joints_3d_visible'][:, 0] > 0)\n            query_valid = (query_obj['joints_3d_visible'][:, 0] > 0)\n\n            valid_s = np.where(sample_valid)[0]\n            valid_q = np.where(query_valid)[0]\n            valid_sq = np.where(sample_valid & query_valid)[0]\n            if len(valid_sq) > 0:\n                kpt_id = 
np.random.choice(valid_sq)\n            elif len(valid_s) > 0:\n                kpt_id = np.random.choice(valid_s)\n            elif len(valid_q) > 0:\n                kpt_id = np.random.choice(valid_q)\n            else:\n                kpt_id = np.random.choice(np.array(range(len(query_valid))))\n\n            for i in range(self.num_shots):\n                sample_obj_list[i] = self._select_kpt(sample_obj_list[i], kpt_id)\n            query_obj = self._select_kpt(query_obj, kpt_id)\n\n        # when test, all keypoints will be preserved.\n\n        Xs_list = []\n        for sample_obj in sample_obj_list:\n            Xs = self.pipeline(sample_obj)\n            Xs_list.append(Xs)\n        Xq = self.pipeline(query_obj)\n\n        Xall = self._merge_obj(Xs_list, Xq, idx)\n        Xall['skeleton'] = self.db[query_id]['skeleton']\n\n        return Xall\n\n    def _sort_and_unique_bboxes(self, kpts, key='bbox_id'):\n        \"\"\"sort kpts and remove the repeated ones.\"\"\"\n        kpts = sorted(kpts, key=lambda x: x[key])\n        num = len(kpts)\n        for i in range(num - 1, 0, -1):\n            if kpts[i][key] == kpts[i - 1][key]:\n                del kpts[i]\n\n        return kpts\n"
  },
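  {
    "path": "examples/sketch_eval_mask_and_pck.py",
    "content": "\"\"\"NumPy sketch of the evaluation masking in FewShotBaseDataset._report_metric:\na keypoint is scored only if it is visible in the query and in every support\nsample, and PCK normalizes the error by the longer bbox side. All values below\nare made up.\"\"\"\n\nimport numpy as np\n\nquery_vis = np.array([1, 1, 0, 1]) > 0\nsupport_vis = [np.array([1, 0, 1, 1]) > 0,\n               np.array([1, 1, 1, 1]) > 0]  # 2-shot support\n\nmask = query_vis.copy()\nfor vis in support_vis:\n    mask = np.bitwise_and(mask, vis)  # intersect visibilities\n# -> only keypoints 0 and 3 remain\n\npred = np.array([[10., 10.], [50., 50.], [0., 0.], [30., 42.]])\ngt = np.array([[12., 11.], [55., 58.], [0., 0.], [30., 40.]])\nbbox_side = 100.0  # max(bbox w, h), the PCK normalizer\n\ndist = np.linalg.norm(pred - gt, axis=1) / bbox_side\nprint(mask, np.mean(dist[mask] <= 0.2))  # PCK@0.2 over masked keypoints\n"
  },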
  {
    "path": "models/datasets/datasets/mp100/fewshot_dataset.py",
    "content": "import os\nimport random\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom mmpose.datasets import DATASETS\nfrom xtcocotools.coco import COCO\n\nfrom .fewshot_base_dataset import FewShotBaseDataset\n\n\n@DATASETS.register_module()\nclass FewShotKeypointDataset(FewShotBaseDataset):\n\n    def __init__(self,\n                 ann_file,\n                 img_prefix,\n                 data_cfg,\n                 pipeline,\n                 valid_class_ids,\n                 num_shots=1,\n                 num_queries=100,\n                 num_episodes=1,\n                 test_mode=False):\n        super().__init__(\n            ann_file, img_prefix, data_cfg, pipeline, test_mode=test_mode)\n\n        self.ann_info['flip_pairs'] = []\n\n        self.ann_info['upper_body_ids'] = []\n        self.ann_info['lower_body_ids'] = []\n\n        self.ann_info['use_different_joint_weights'] = False\n        self.ann_info['joint_weights'] = np.array([1., ],\n                                                  dtype=np.float32).reshape((self.ann_info['num_joints'], 1))\n\n        self.coco = COCO(ann_file)\n\n        self.id2name, self.name2id = self._get_mapping_id_name(self.coco.imgs)\n        self.img_ids = self.coco.getImgIds()\n        self.classes = [\n            cat['name'] for cat in self.coco.loadCats(self.coco.getCatIds())\n        ]\n\n        self.num_classes = len(self.classes)\n        self._class_to_ind = dict(zip(self.classes, self.coco.getCatIds()))\n        self._ind_to_class = dict(zip(self.coco.getCatIds(), self.classes))\n\n        if valid_class_ids is not None:\n            self.valid_class_ids = valid_class_ids\n        else:\n            self.valid_class_ids = self.coco.getCatIds()\n        self.valid_classes = [self._ind_to_class[ind] for ind in self.valid_class_ids]\n\n        self.cats = self.coco.cats\n\n        # Also update self.cat2obj\n        self.db = self._get_db()\n\n        self.num_shots = num_shots\n\n        if not test_mode:\n            # Update every training epoch\n            self.random_paired_samples()\n        else:\n            self.num_queries = num_queries\n            self.num_episodes = num_episodes\n            self.make_paired_samples()\n\n    def random_paired_samples(self):\n        num_datas = [len(self.cat2obj[self._class_to_ind[cls]]) for cls in self.valid_classes]\n\n        # balance the dataset\n        max_num_data = max(num_datas)\n\n        all_samples = []\n        for cls in self.valid_class_ids:\n            for i in range(max_num_data):\n                shot = random.sample(self.cat2obj[cls], self.num_shots + 1)\n                all_samples.append(shot)\n\n        self.paired_samples = np.array(all_samples)\n        np.random.shuffle(self.paired_samples)\n\n    def make_paired_samples(self):\n        random.seed(1)\n        np.random.seed(0)\n\n        all_samples = []\n        for cls in self.valid_class_ids:\n            for _ in range(self.num_episodes):\n                shots = random.sample(self.cat2obj[cls], self.num_shots + self.num_queries)\n                sample_ids = shots[:self.num_shots]\n                query_ids = shots[self.num_shots:]\n                for query_id in query_ids:\n                    all_samples.append(sample_ids + [query_id])\n\n        self.paired_samples = np.array(all_samples)\n\n    def _select_kpt(self, obj, kpt_id):\n        obj['joints_3d'] = obj['joints_3d'][kpt_id:kpt_id + 1]\n        obj['joints_3d_visible'] = obj['joints_3d_visible'][kpt_id:kpt_id + 1]\n        
obj['kpt_id'] = kpt_id\n\n        return obj\n\n    @staticmethod\n    def _get_mapping_id_name(imgs):\n        \"\"\"\n        Args:\n            imgs (dict): dict of image info.\n\n        Returns:\n            tuple: Image name & id mapping dicts.\n\n            - id2name (dict): Mapping image id to name.\n            - name2id (dict): Mapping image name to id.\n        \"\"\"\n        id2name = {}\n        name2id = {}\n        for image_id, image in imgs.items():\n            file_name = image['file_name']\n            id2name[image_id] = file_name\n            name2id[file_name] = image_id\n\n        return id2name, name2id\n\n    def _get_db(self):\n        \"\"\"Ground truth bbox and keypoints.\"\"\"\n        self.obj_id = 0\n\n        self.cat2obj = {}\n        for i in self.coco.getCatIds():\n            self.cat2obj.update({i: []})\n\n        gt_db = []\n        for img_id in self.img_ids:\n            gt_db.extend(self._load_coco_keypoint_annotation_kernel(img_id))\n        return gt_db\n\n    def _load_coco_keypoint_annotation_kernel(self, img_id):\n        \"\"\"load annotation from COCOAPI.\n\n        Note:\n            bbox:[x1, y1, w, h]\n        Args:\n            img_id: coco image id\n        Returns:\n            dict: db entry\n        \"\"\"\n        img_ann = self.coco.loadImgs(img_id)[0]\n        width = img_ann['width']\n        height = img_ann['height']\n\n        ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=False)\n        objs = self.coco.loadAnns(ann_ids)\n\n        # sanitize bboxes\n        valid_objs = []\n        for obj in objs:\n            if 'bbox' not in obj:\n                continue\n            x, y, w, h = obj['bbox']\n            x1 = max(0, x)\n            y1 = max(0, y)\n            x2 = min(width - 1, x1 + max(0, w - 1))\n            y2 = min(height - 1, y1 + max(0, h - 1))\n            if ('area' not in obj or obj['area'] > 0) and x2 > x1 and y2 > y1:\n                obj['clean_bbox'] = [x1, y1, x2 - x1, y2 - y1]\n                valid_objs.append(obj)\n        objs = valid_objs\n\n        bbox_id = 0\n        rec = []\n        for obj in objs:\n            if 'keypoints' not in obj:\n                continue\n            if max(obj['keypoints']) == 0:\n                continue\n            if 'num_keypoints' in obj and obj['num_keypoints'] == 0:\n                continue\n\n            category_id = obj['category_id']\n            # the number of keypoint for this specific category\n            cat_kpt_num = int(len(obj['keypoints']) / 3)\n\n            joints_3d = np.zeros((cat_kpt_num, 3), dtype=np.float32)\n            joints_3d_visible = np.zeros((cat_kpt_num, 3), dtype=np.float32)\n\n            keypoints = np.array(obj['keypoints']).reshape(-1, 3)\n            joints_3d[:, :2] = keypoints[:, :2]\n            joints_3d_visible[:, :2] = np.minimum(1, keypoints[:, 2:3])\n\n            center, scale = self._xywh2cs(*obj['clean_bbox'][:4])\n\n            image_file = os.path.join(self.img_prefix, self.id2name[img_id])\n\n            self.cat2obj[category_id].append(self.obj_id)\n\n            rec.append({\n                'image_file': image_file,\n                'center': center,\n                'scale': scale,\n                'rotation': 0,\n                'bbox': obj['clean_bbox'][:4],\n                'bbox_score': 1,\n                'joints_3d': joints_3d,\n                'joints_3d_visible': joints_3d_visible,\n                'category_id': category_id,\n                'cat_kpt_num': cat_kpt_num,\n                
'bbox_id': self.obj_id,\n                'skeleton': self.coco.cats[obj['category_id']]['skeleton'],\n            })\n            bbox_id = bbox_id + 1\n            self.obj_id += 1\n\n        return rec\n\n    def _xywh2cs(self, x, y, w, h):\n        \"\"\"Encode a bbox (x, y, w, h) into (center, scale).\n\n        Args:\n            x, y, w, h: the bbox top-left corner and its width / height.\n\n        Returns:\n            tuple: A tuple containing center and scale.\n\n            - center (np.ndarray[float32](2,)): center of the bbox (x, y).\n            - scale (np.ndarray[float32](2,)): scale of the bbox w & h.\n        \"\"\"\n        aspect_ratio = self.ann_info['image_size'][0] / self.ann_info['image_size'][1]\n        center = np.array([x + w * 0.5, y + h * 0.5], dtype=np.float32)\n\n        # (disabled) random center jitter augmentation:\n        # if (not self.test_mode) and np.random.rand() < 0.3:\n        #     center += 0.4 * (np.random.rand(2) - 0.5) * [w, h]\n\n        if w > aspect_ratio * h:\n            h = w * 1.0 / aspect_ratio\n        elif w < aspect_ratio * h:\n            w = h * aspect_ratio\n\n        # pixel std is 200.0\n        scale = np.array([w / 200.0, h / 200.0], dtype=np.float32)\n        # padding to include proper amount of context\n        scale = scale * 1.25\n\n        return center, scale\n\n    def evaluate(self, outputs, res_folder, metric='PCK', **kwargs):\n        \"\"\"Evaluate MP-100 few-shot keypoint results. The pose prediction\n        results will be saved in `${res_folder}/result_keypoints.json`.\n\n        Note:\n            batch_size: N\n            num_keypoints: K\n            heatmap height: H\n            heatmap width: W\n\n        Args:\n            outputs (list(preds, boxes, image_paths, output_heatmap)):\n                :preds (np.ndarray[N,K,3]): The first two dimensions are\n                    coordinates, score is the third dimension of the array.\n                :boxes (np.ndarray[N,6]): [center[0], center[1], scale[0],\n                    scale[1], area, score]\n                :image_paths (list[str]): Paths of the query images.\n                :output_heatmap (np.ndarray[N, K, H, W]): model outputs.\n\n            res_folder (str): Path of directory to save the results.\n            metric (str | list[str]): Metric to be performed.\n                Options: 'PCK', 'AUC', 'EPE'.\n\n        Returns:\n            dict: Evaluation results for evaluation metric.\n        \"\"\"\n        metrics = metric if isinstance(metric, list) else [metric]\n        allowed_metrics = ['PCK', 'AUC', 'EPE']\n        for metric in metrics:\n            if metric not in allowed_metrics:\n                raise KeyError(f'metric {metric} is not supported')\n\n        res_file = os.path.join(res_folder, 'result_keypoints.json')\n\n        kpts = []\n        for output in outputs:\n            preds = output['preds']\n            boxes = output['boxes']\n            image_paths = output['image_paths']\n            bbox_ids = output['bbox_ids']\n\n            batch_size = len(image_paths)\n            for i in range(batch_size):\n                image_id = self.name2id[image_paths[i][len(self.img_prefix):]]\n\n                kpts.append({\n                    'keypoints': preds[i].tolist(),\n                    'center': boxes[i][0:2].tolist(),\n                    'scale': boxes[i][2:4].tolist(),\n                    'area': float(boxes[i][4]),\n                    'score': float(boxes[i][5]),\n                    'image_id': image_id,\n                    'bbox_id': bbox_ids[i]\n                })\n        kpts = self._sort_and_unique_bboxes(kpts)\n\n        self._write_keypoint_results(kpts, res_file)\n        info_str = self._report_metric(res_file, metrics)\n        name_value = OrderedDict(info_str)\n\n        return name_value\n"
  },
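  {
    "path": "examples/sketch_xywh2cs.py",
    "content": "\"\"\"Standalone sketch of the bbox -> (center, scale) encoding in _xywh2cs: the\nbox is padded to the model aspect ratio, divided by the 200-pixel std, then\nenlarged by 1.25x for context. image_size stands in for a config value.\"\"\"\n\nimport numpy as np\n\n\ndef xywh2cs(x, y, w, h, image_size=(256, 256)):\n    aspect_ratio = image_size[0] / image_size[1]\n    center = np.array([x + w * 0.5, y + h * 0.5], dtype=np.float32)\n    if w > aspect_ratio * h:  # too wide -> grow h\n        h = w * 1.0 / aspect_ratio\n    elif w < aspect_ratio * h:  # too tall -> grow w\n        w = h * aspect_ratio\n    scale = np.array([w / 200.0, h / 200.0], dtype=np.float32) * 1.25\n    return center, scale\n\n\n# a 100x50 box becomes a square 100x100 box before scaling\nprint(xywh2cs(10, 20, 100, 50))  # center (60., 45.), scale (0.625, 0.625)\n"
  },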
  {
    "path": "models/datasets/datasets/mp100/test_base_dataset.py",
    "content": "import copy\nfrom abc import ABCMeta, abstractmethod\n\nimport json_tricks as json\nimport numpy as np\nfrom mmcv.parallel import DataContainer as DC\nfrom mmpose.core.evaluation.top_down_eval import (keypoint_auc, keypoint_epe, keypoint_nme,\n                                                  keypoint_pck_accuracy)\nfrom mmpose.datasets import DATASETS\nfrom mmpose.datasets.pipelines import Compose\nfrom torch.utils.data import Dataset\n\n\n@DATASETS.register_module()\nclass TestBaseDataset(Dataset, metaclass=ABCMeta):\n\n    def __init__(self,\n                 ann_file,\n                 img_prefix,\n                 data_cfg,\n                 pipeline,\n                 test_mode=True,\n                 PCK_threshold_list=[0.05, 0.1, 0.15, 0.2, 0.25]):\n        self.image_info = {}\n        self.ann_info = {}\n\n        self.annotations_path = ann_file\n        if not img_prefix.endswith('/'):\n            img_prefix = img_prefix + '/'\n        self.img_prefix = img_prefix\n        self.pipeline = pipeline\n        self.test_mode = test_mode\n        self.PCK_threshold_list = PCK_threshold_list\n\n        self.ann_info['image_size'] = np.array(data_cfg['image_size'])\n        self.ann_info['heatmap_size'] = np.array(data_cfg['heatmap_size'])\n        self.ann_info['num_joints'] = data_cfg['num_joints']\n\n        self.ann_info['flip_pairs'] = None\n\n        self.ann_info['inference_channel'] = data_cfg['inference_channel']\n        self.ann_info['num_output_channels'] = data_cfg['num_output_channels']\n        self.ann_info['dataset_channel'] = data_cfg['dataset_channel']\n\n        self.db = []\n        self.num_shots = 1\n        self.paired_samples = []\n        self.pipeline = Compose(self.pipeline)\n\n    @abstractmethod\n    def _get_db(self):\n        \"\"\"Load dataset.\"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    def _select_kpt(self, obj, kpt_id):\n        \"\"\"Select kpt.\"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    def evaluate(self, cfg, preds, output_dir, *args, **kwargs):\n        \"\"\"Evaluate keypoint results.\"\"\"\n        raise NotImplementedError\n\n    @staticmethod\n    def _write_keypoint_results(keypoints, res_file):\n        \"\"\"Write results into a json file.\"\"\"\n\n        with open(res_file, 'w') as f:\n            json.dump(keypoints, f, sort_keys=True, indent=4)\n\n    def _report_metric(self,\n                       res_file,\n                       metrics):\n        \"\"\"Keypoint evaluation.\n\n        Args:\n            res_file (str): Json file stored prediction results.\n            metrics (str | list[str]): Metric to be performed.\n                Options: 'PCK', 'PCKh', 'AUC', 'EPE'.\n            pck_thr (float): PCK threshold, default as 0.2.\n            pckh_thr (float): PCKh threshold, default as 0.7.\n            auc_nor (float): AUC normalization factor, default as 30 pixel.\n\n        Returns:\n            List: Evaluation results for evaluation metric.\n        \"\"\"\n        info_str = []\n\n        with open(res_file, 'r') as fin:\n            preds = json.load(fin)\n        assert len(preds) == len(self.paired_samples)\n\n        outputs = []\n        gts = []\n        masks = []\n        threshold_bbox = []\n        threshold_head_box = []\n\n        for pred, pair in zip(preds, self.paired_samples):\n            item = self.db[pair[-1]]\n            outputs.append(np.array(pred['keypoints'])[:, :-1])\n            gts.append(np.array(item['joints_3d'])[:, 
:-1])\n\n            mask_query = ((np.array(item['joints_3d_visible'])[:, 0]) > 0)\n            mask_sample = ((np.array(self.db[pair[0]]['joints_3d_visible'])[:, 0]) > 0)\n            for id_s in pair[:-1]:\n                mask_sample = np.bitwise_and(mask_sample, ((np.array(self.db[id_s]['joints_3d_visible'])[:, 0]) > 0))\n            masks.append(np.bitwise_and(mask_query, mask_sample))\n\n            if 'PCK' in metrics or 'NME' in metrics or 'AUC' in metrics:\n                bbox = np.array(item['bbox'])\n                bbox_thr = np.max(bbox[2:])\n                threshold_bbox.append(np.array([bbox_thr, bbox_thr]))\n            if 'PCKh' in metrics:\n                head_box_thr = item['head_size']\n                threshold_head_box.append(\n                    np.array([head_box_thr, head_box_thr]))\n\n        if 'PCK' in metrics:\n            pck_results = dict()\n            for pck_thr in self.PCK_threshold_list:\n                pck_results[pck_thr] = []\n\n            for (output, gt, mask, thr_bbox) in zip(outputs, gts, masks, threshold_bbox):\n                for pck_thr in self.PCK_threshold_list:\n                    _, pck, _ = keypoint_pck_accuracy(np.expand_dims(output, 0), np.expand_dims(gt, 0),\n                                                      np.expand_dims(mask, 0), pck_thr, np.expand_dims(thr_bbox, 0))\n                    pck_results[pck_thr].append(pck)\n\n            mPCK = 0\n            for pck_thr in self.PCK_threshold_list:\n                info_str.append(['PCK@' + str(pck_thr), np.mean(pck_results[pck_thr])])\n                mPCK += np.mean(pck_results[pck_thr])\n            info_str.append(['mPCK', mPCK / len(self.PCK_threshold_list)])\n\n        if 'NME' in metrics:\n            nme_results = []\n            for (output, gt, mask, thr_bbox) in zip(outputs, gts, masks, threshold_bbox):\n                nme = keypoint_nme(np.expand_dims(output, 0), np.expand_dims(gt, 0), np.expand_dims(mask, 0),\n                                   np.expand_dims(thr_bbox, 0))\n                nme_results.append(nme)\n            info_str.append(['NME', np.mean(nme_results)])\n\n        if 'AUC' in metrics:\n            auc_results = []\n            for (output, gt, mask, thr_bbox) in zip(outputs, gts, masks, threshold_bbox):\n                auc = keypoint_auc(np.expand_dims(output, 0), np.expand_dims(gt, 0), np.expand_dims(mask, 0),\n                                   thr_bbox[0])\n                auc_results.append(auc)\n            info_str.append(['AUC', np.mean(auc_results)])\n\n        if 'EPE' in metrics:\n            epe_results = []\n            for (output, gt, mask) in zip(outputs, gts, masks):\n                epe = keypoint_epe(np.expand_dims(output, 0), np.expand_dims(gt, 0), np.expand_dims(mask, 0))\n                epe_results.append(epe)\n            info_str.append(['EPE', np.mean(epe_results)])\n        return info_str\n\n    def _merge_obj(self, Xs_list, Xq, idx):\n        \"\"\" merge Xs_list and Xq.\n\n        :param Xs_list: N-shot samples X\n        :param Xq: query X\n        :param idx: id of paired_samples\n        :return: Xall\n        \"\"\"\n        Xall = dict()\n        Xall['img_s'] = [Xs['img'] for Xs in Xs_list]\n        Xall['target_s'] = [Xs['target'] for Xs in Xs_list]\n        Xall['target_weight_s'] = [Xs['target_weight'] for Xs in Xs_list]\n        xs_img_metas = [Xs['img_metas'].data for Xs in Xs_list]\n\n        Xall['img_q'] = Xq['img']\n        Xall['target_q'] = Xq['target']\n        Xall['target_weight_q'] = 
Xq['target_weight']\n        xq_img_metas = Xq['img_metas'].data\n\n        img_metas = dict()\n        for key in xq_img_metas.keys():\n            img_metas['sample_' + key] = [xs_img_meta[key] for xs_img_meta in xs_img_metas]\n            img_metas['query_' + key] = xq_img_metas[key]\n        img_metas['bbox_id'] = idx\n\n        Xall['img_metas'] = DC(img_metas, cpu_only=True)\n\n        return Xall\n\n    def __len__(self):\n        \"\"\"Get the size of the dataset.\"\"\"\n        return len(self.paired_samples)\n\n    def __getitem__(self, idx):\n        \"\"\"Get the sample given index.\"\"\"\n\n        pair_ids = self.paired_samples[idx]  # [supported id * shots, query id]\n        assert len(pair_ids) == self.num_shots + 1\n        sample_id_list = pair_ids[:self.num_shots]\n        query_id = pair_ids[-1]\n\n        sample_obj_list = []\n        for sample_id in sample_id_list:\n            sample_obj = copy.deepcopy(self.db[sample_id])\n            sample_obj['ann_info'] = copy.deepcopy(self.ann_info)\n            sample_obj_list.append(sample_obj)\n\n        query_obj = copy.deepcopy(self.db[query_id])\n        query_obj['ann_info'] = copy.deepcopy(self.ann_info)\n\n        Xs_list = []\n        for sample_obj in sample_obj_list:\n            Xs = self.pipeline(sample_obj)  # dict with ['img', 'target', 'target_weight', 'img_metas'],\n            Xs_list.append(Xs)  # Xs['target'] is of shape [100, map_h, map_w]\n        Xq = self.pipeline(query_obj)\n\n        Xall = self._merge_obj(Xs_list, Xq, idx)\n        Xall['skeleton'] = self.db[query_id]['skeleton']\n\n        return Xall\n\n    def _sort_and_unique_bboxes(self, kpts, key='bbox_id'):\n        \"\"\"sort kpts and remove the repeated ones.\"\"\"\n        kpts = sorted(kpts, key=lambda x: x[key])\n        num = len(kpts)\n        for i in range(num - 1, 0, -1):\n            if kpts[i][key] == kpts[i - 1][key]:\n                del kpts[i]\n\n        return kpts\n"
  },
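  {
    "path": "examples/sketch_mpck.py",
    "content": "\"\"\"Sketch of the mPCK reported by TestBaseDataset: PCK is computed at each\nthreshold in PCK_threshold_list and the values are averaged. The normalized\nkeypoint errors below are made up.\"\"\"\n\nimport numpy as np\n\ndist = np.array([0.03, 0.12, 0.18, 0.07, 0.26])  # error / max(bbox w, h)\nthresholds = [0.05, 0.1, 0.15, 0.2, 0.25]\n\npcks = [np.mean(dist <= t) for t in thresholds]\nfor t, p in zip(thresholds, pcks):\n    print(f'PCK@{t}: {p:.2f}')\nprint('mPCK:', np.mean(pcks))  # mean over the threshold list\n"
  },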
  {
    "path": "models/datasets/datasets/mp100/test_dataset.py",
    "content": "import os\nimport random\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom mmpose.datasets import DATASETS\nfrom xtcocotools.coco import COCO\n\nfrom .test_base_dataset import TestBaseDataset\n\n\n@DATASETS.register_module()\nclass TestPoseDataset(TestBaseDataset):\n\n    def __init__(self,\n                 ann_file,\n                 img_prefix,\n                 data_cfg,\n                 pipeline,\n                 valid_class_ids,\n                 max_kpt_num=None,\n                 num_shots=1,\n                 num_queries=100,\n                 num_episodes=1,\n                 pck_threshold_list=[0.05, 0.1, 0.15, 0.20, 0.25],\n                 test_mode=True):\n        super().__init__(\n            ann_file, img_prefix, data_cfg, pipeline, test_mode=test_mode, PCK_threshold_list=pck_threshold_list)\n\n        self.ann_info['flip_pairs'] = []\n\n        self.ann_info['upper_body_ids'] = []\n        self.ann_info['lower_body_ids'] = []\n\n        self.ann_info['use_different_joint_weights'] = False\n        self.ann_info['joint_weights'] = np.array([1., ],\n                                                  dtype=np.float32).reshape((self.ann_info['num_joints'], 1))\n\n        self.coco = COCO(ann_file)\n\n        self.id2name, self.name2id = self._get_mapping_id_name(self.coco.imgs)\n        self.img_ids = self.coco.getImgIds()\n        self.classes = [\n            cat['name'] for cat in self.coco.loadCats(self.coco.getCatIds())\n        ]\n\n        self.num_classes = len(self.classes)\n        self._class_to_ind = dict(zip(self.classes, self.coco.getCatIds()))\n        self._ind_to_class = dict(zip(self.coco.getCatIds(), self.classes))\n\n        if valid_class_ids is not None:  # None by default\n            self.valid_class_ids = valid_class_ids\n        else:\n            self.valid_class_ids = self.coco.getCatIds()\n        self.valid_classes = [self._ind_to_class[ind] for ind in self.valid_class_ids]\n\n        self.cats = self.coco.cats\n        self.max_kpt_num = max_kpt_num\n\n        # Also update self.cat2obj\n        self.db = self._get_db()\n\n        self.num_shots = num_shots\n\n        if not test_mode:\n            # Update every training epoch\n            self.random_paired_samples()\n        else:\n            self.num_queries = num_queries\n            self.num_episodes = num_episodes\n            self.make_paired_samples()\n\n    def random_paired_samples(self):\n        num_datas = [len(self.cat2obj[self._class_to_ind[cls]]) for cls in self.valid_classes]\n\n        # balance the dataset\n        max_num_data = max(num_datas)\n\n        all_samples = []\n        for cls in self.valid_class_ids:\n            for i in range(max_num_data):\n                shot = random.sample(self.cat2obj[cls], self.num_shots + 1)\n                all_samples.append(shot)\n\n        self.paired_samples = np.array(all_samples)\n        np.random.shuffle(self.paired_samples)\n\n    def make_paired_samples(self):\n        random.seed(1)\n        np.random.seed(0)\n\n        all_samples = []\n        for cls in self.valid_class_ids:\n            for _ in range(self.num_episodes):\n                shots = random.sample(self.cat2obj[cls], self.num_shots + self.num_queries)\n                sample_ids = shots[:self.num_shots]\n                query_ids = shots[self.num_shots:]\n                for query_id in query_ids:\n                    all_samples.append(sample_ids + [query_id])\n\n        self.paired_samples = np.array(all_samples)\n\n    def 
_select_kpt(self, obj, kpt_id):\n        obj['joints_3d'] = obj['joints_3d'][kpt_id:kpt_id + 1]\n        obj['joints_3d_visible'] = obj['joints_3d_visible'][kpt_id:kpt_id + 1]\n        obj['kpt_id'] = kpt_id\n\n        return obj\n\n    @staticmethod\n    def _get_mapping_id_name(imgs):\n        \"\"\"\n        Args:\n            imgs (dict): dict of image info.\n\n        Returns:\n            tuple: Image name & id mapping dicts.\n\n            - id2name (dict): Mapping image id to name.\n            - name2id (dict): Mapping image name to id.\n        \"\"\"\n        id2name = {}\n        name2id = {}\n        for image_id, image in imgs.items():\n            file_name = image['file_name']\n            id2name[image_id] = file_name\n            name2id[file_name] = image_id\n\n        return id2name, name2id\n\n    def _get_db(self):\n        \"\"\"Ground truth bbox and keypoints.\"\"\"\n        self.obj_id = 0\n\n        self.cat2obj = {}\n        for i in self.coco.getCatIds():\n            self.cat2obj.update({i: []})\n\n        gt_db = []\n        for img_id in self.img_ids:\n            gt_db.extend(self._load_coco_keypoint_annotation_kernel(img_id))\n        return gt_db\n\n    def _load_coco_keypoint_annotation_kernel(self, img_id):\n        \"\"\"load annotation from COCOAPI.\n\n        Note:\n            bbox:[x1, y1, w, h]\n        Args:\n            img_id: coco image id\n        Returns:\n            dict: db entry\n        \"\"\"\n        img_ann = self.coco.loadImgs(img_id)[0]\n        width = img_ann['width']\n        height = img_ann['height']\n\n        ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=False)\n        objs = self.coco.loadAnns(ann_ids)\n\n        # sanitize bboxes\n        valid_objs = []\n        for obj in objs:\n            if 'bbox' not in obj:\n                continue\n            x, y, w, h = obj['bbox']\n            x1 = max(0, x)\n            y1 = max(0, y)\n            x2 = min(width - 1, x1 + max(0, w - 1))\n            y2 = min(height - 1, y1 + max(0, h - 1))\n            if ('area' not in obj or obj['area'] > 0) and x2 > x1 and y2 > y1:\n                obj['clean_bbox'] = [x1, y1, x2 - x1, y2 - y1]\n                valid_objs.append(obj)\n        objs = valid_objs\n\n        bbox_id = 0\n        rec = []\n        for obj in objs:\n            if 'keypoints' not in obj:\n                continue\n            if max(obj['keypoints']) == 0:\n                continue\n            if 'num_keypoints' in obj and obj['num_keypoints'] == 0:\n                continue\n\n            category_id = obj['category_id']\n            # the number of keypoint for this specific category\n            cat_kpt_num = int(len(obj['keypoints']) / 3)\n            if self.max_kpt_num is None:\n                kpt_num = cat_kpt_num\n            else:\n                kpt_num = self.max_kpt_num\n\n            joints_3d = np.zeros((kpt_num, 3), dtype=np.float32)\n            joints_3d_visible = np.zeros((kpt_num, 3), dtype=np.float32)\n\n            keypoints = np.array(obj['keypoints']).reshape(-1, 3)\n            joints_3d[:cat_kpt_num, :2] = keypoints[:, :2]\n            joints_3d_visible[:cat_kpt_num, :2] = np.minimum(1, keypoints[:, 2:3])\n\n            center, scale = self._xywh2cs(*obj['clean_bbox'][:4])\n\n            image_file = os.path.join(self.img_prefix, self.id2name[img_id])\n\n            self.cat2obj[category_id].append(self.obj_id)\n\n            rec.append({\n                'image_file': image_file,\n                'center': center,\n            
    'scale': scale,\n                'rotation': 0,\n                'bbox': obj['clean_bbox'][:4],\n                'bbox_score': 1,\n                'joints_3d': joints_3d,\n                'joints_3d_visible': joints_3d_visible,\n                'category_id': category_id,\n                'cat_kpt_num': cat_kpt_num,\n                'bbox_id': self.obj_id,\n                'skeleton': self.coco.cats[obj['category_id']]['skeleton'],\n            })\n            bbox_id = bbox_id + 1\n            self.obj_id += 1\n\n        return rec\n\n    def _xywh2cs(self, x, y, w, h):\n        \"\"\"Encode a bbox (x, y, w, h) into (center, scale).\n\n        Args:\n            x, y, w, h: the bbox top-left corner and its width / height.\n\n        Returns:\n            tuple: A tuple containing center and scale.\n\n            - center (np.ndarray[float32](2,)): center of the bbox (x, y).\n            - scale (np.ndarray[float32](2,)): scale of the bbox w & h.\n        \"\"\"\n        aspect_ratio = self.ann_info['image_size'][0] / self.ann_info['image_size'][1]\n        center = np.array([x + w * 0.5, y + h * 0.5], dtype=np.float32)\n\n        # (disabled) random center jitter augmentation:\n        # if (not self.test_mode) and np.random.rand() < 0.3:\n        #     center += 0.4 * (np.random.rand(2) - 0.5) * [w, h]\n\n        if w > aspect_ratio * h:\n            h = w * 1.0 / aspect_ratio\n        elif w < aspect_ratio * h:\n            w = h * aspect_ratio\n\n        # pixel std is 200.0\n        scale = np.array([w / 200.0, h / 200.0], dtype=np.float32)\n        # padding to include proper amount of context\n        scale = scale * 1.25\n\n        return center, scale\n\n    def evaluate(self, outputs, res_folder, metric='PCK', **kwargs):\n        \"\"\"Evaluate MP-100 keypoint results. The pose prediction results will\n        be saved in `${res_folder}/result_keypoints.json`.\n\n        Note:\n            batch_size: N\n            num_keypoints: K\n            heatmap height: H\n            heatmap width: W\n\n        Args:\n            outputs (list(preds, boxes, image_paths, output_heatmap)):\n                :preds (np.ndarray[N,K,3]): The first two dimensions are\n                    coordinates, score is the third dimension of the array.\n                :boxes (np.ndarray[N,6]): [center[0], center[1], scale[0],\n                    scale[1], area, score]\n                :image_paths (list[str]): Paths of the query images.\n                :output_heatmap (np.ndarray[N, K, H, W]): model outputs.\n\n            res_folder (str): Path of directory to save the results.\n            metric (str | list[str]): Metric to be performed.\n                Options: 'PCK', 'NME', 'AUC', 'EPE'.\n\n        Returns:\n            dict: Evaluation results for evaluation metric.\n        \"\"\"\n        metrics = metric if isinstance(metric, list) else [metric]\n        allowed_metrics = ['PCK', 'AUC', 'EPE', 'NME']\n        for metric in metrics:\n            if metric not in allowed_metrics:\n                raise KeyError(f'metric {metric} is not supported')\n\n        res_file = os.path.join(res_folder, 'result_keypoints.json')\n\n        kpts = []\n        for output in outputs:\n            preds = output['preds']\n            boxes = output['boxes']\n            image_paths = output['image_paths']\n            bbox_ids = output['bbox_ids']\n\n            batch_size = len(image_paths)\n            for i in range(batch_size):\n                image_id = self.name2id[image_paths[i][len(self.img_prefix):]]\n\n                kpts.append({\n                    'keypoints': preds[i].tolist(),\n                    'center': boxes[i][0:2].tolist(),\n                    'scale': boxes[i][2:4].tolist(),\n                    'area': float(boxes[i][4]),\n                    'score': float(boxes[i][5]),\n                    'image_id': image_id,\n                    'bbox_id': bbox_ids[i]\n                })\n        kpts = self._sort_and_unique_bboxes(kpts)\n\n        self._write_keypoint_results(kpts, res_file)\n        info_str = self._report_metric(res_file, metrics)\n        name_value = OrderedDict(info_str)\n\n        return name_value\n"
  },
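  {
    "path": "examples/sketch_max_kpt_num_padding.py",
    "content": "\"\"\"Sketch of the max_kpt_num padding in TestPoseDataset: categories with fewer\nkeypoints are zero-padded to a fixed size so episodes from different\ncategories share one tensor shape. Shapes and values are illustrative.\"\"\"\n\nimport numpy as np\n\ncat_kpt_num, max_kpt_num = 5, 8\nkeypoints = np.arange(cat_kpt_num * 3, dtype=np.float32).reshape(-1, 3)  # (x, y, v) rows\n\njoints_3d = np.zeros((max_kpt_num, 3), dtype=np.float32)\njoints_3d_visible = np.zeros((max_kpt_num, 3), dtype=np.float32)\njoints_3d[:cat_kpt_num, :2] = keypoints[:, :2]\njoints_3d_visible[:cat_kpt_num, :2] = np.minimum(1, keypoints[:, 2:3])\n\nprint(joints_3d.shape)  # (8, 3); rows 5-7 stay zero\nprint(joints_3d_visible[:, 0])  # padded keypoints remain invisible\n"
  },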
  {
    "path": "models/datasets/datasets/mp100/transformer_base_dataset.py",
    "content": "import copy\nfrom abc import ABCMeta, abstractmethod\n\nimport json_tricks as json\nimport numpy as np\nfrom mmcv.parallel import DataContainer as DC\nfrom mmpose.core.evaluation.top_down_eval import (keypoint_pck_accuracy)\nfrom mmpose.datasets import DATASETS\nfrom mmpose.datasets.pipelines import Compose\nfrom torch.utils.data import Dataset\n\n\n@DATASETS.register_module()\nclass TransformerBaseDataset(Dataset, metaclass=ABCMeta):\n\n    def __init__(self,\n                 ann_file,\n                 img_prefix,\n                 data_cfg,\n                 pipeline,\n                 test_mode=False):\n        self.image_info = {}\n        self.ann_info = {}\n\n        self.annotations_path = ann_file\n        if not img_prefix.endswith('/'):\n            img_prefix = img_prefix + '/'\n        self.img_prefix = img_prefix\n        self.pipeline = pipeline\n        self.test_mode = test_mode\n\n        self.ann_info['image_size'] = np.array(data_cfg['image_size'])\n        self.ann_info['heatmap_size'] = np.array(data_cfg['heatmap_size'])\n        self.ann_info['num_joints'] = data_cfg['num_joints']\n\n        self.ann_info['flip_pairs'] = None\n\n        self.ann_info['inference_channel'] = data_cfg['inference_channel']\n        self.ann_info['num_output_channels'] = data_cfg['num_output_channels']\n        self.ann_info['dataset_channel'] = data_cfg['dataset_channel']\n\n        self.db = []\n        self.num_shots = 1\n        self.paired_samples = []\n        self.pipeline = Compose(self.pipeline)\n\n    @abstractmethod\n    def _get_db(self):\n        \"\"\"Load dataset.\"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    def _select_kpt(self, obj, kpt_id):\n        \"\"\"Select kpt.\"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    def evaluate(self, cfg, preds, output_dir, *args, **kwargs):\n        \"\"\"Evaluate keypoint results.\"\"\"\n        raise NotImplementedError\n\n    @staticmethod\n    def _write_keypoint_results(keypoints, res_file):\n        \"\"\"Write results into a json file.\"\"\"\n\n        with open(res_file, 'w') as f:\n            json.dump(keypoints, f, sort_keys=True, indent=4)\n\n    def _report_metric(self,\n                       res_file,\n                       metrics,\n                       pck_thr=0.2,\n                       pckh_thr=0.7,\n                       auc_nor=30):\n        \"\"\"Keypoint evaluation.\n\n        Args:\n            res_file (str): Json file stored prediction results.\n            metrics (str | list[str]): Metric to be performed.\n                Options: 'PCK', 'PCKh', 'AUC', 'EPE'.\n            pck_thr (float): PCK threshold, default as 0.2.\n            pckh_thr (float): PCKh threshold, default as 0.7.\n            auc_nor (float): AUC normalization factor, default as 30 pixel.\n\n        Returns:\n            List: Evaluation results for evaluation metric.\n        \"\"\"\n        info_str = []\n\n        with open(res_file, 'r') as fin:\n            preds = json.load(fin)\n        assert len(preds) == len(self.paired_samples)\n\n        outputs = []\n        gts = []\n        masks = []\n        threshold_bbox = []\n        threshold_head_box = []\n\n        for pred, pair in zip(preds, self.paired_samples):\n            item = self.db[pair[-1]]\n            outputs.append(np.array(pred['keypoints'])[:, :-1])\n            gts.append(np.array(item['joints_3d'])[:, :-1])\n\n            mask_query = ((np.array(item['joints_3d_visible'])[:, 0]) > 0)\n            
mask_sample = ((np.array(self.db[pair[0]]['joints_3d_visible'])[:, 0]) > 0)\n            for id_s in pair[:-1]:\n                mask_sample = np.bitwise_and(mask_sample, ((np.array(self.db[id_s]['joints_3d_visible'])[:, 0]) > 0))\n            masks.append(np.bitwise_and(mask_query, mask_sample))\n\n            if 'PCK' in metrics:\n                bbox = np.array(item['bbox'])\n                bbox_thr = np.max(bbox[2:])\n                threshold_bbox.append(np.array([bbox_thr, bbox_thr]))\n            if 'PCKh' in metrics:\n                head_box_thr = item['head_size']\n                threshold_head_box.append(\n                    np.array([head_box_thr, head_box_thr]))\n\n        if 'PCK' in metrics:\n            pck_avg = []\n            for (output, gt, mask, thr_bbox) in zip(outputs, gts, masks, threshold_bbox):\n                _, pck, _ = keypoint_pck_accuracy(np.expand_dims(output, 0), np.expand_dims(gt, 0),\n                                                  np.expand_dims(mask, 0), pck_thr, np.expand_dims(thr_bbox, 0))\n                pck_avg.append(pck)\n            info_str.append(('PCK', np.mean(pck_avg)))\n\n        return info_str\n\n    def _merge_obj(self, Xs_list, Xq, idx):\n        \"\"\" merge Xs_list and Xq.\n\n        :param Xs_list: N-shot samples X\n        :param Xq: query X\n        :param idx: id of paired_samples\n        :return: Xall\n        \"\"\"\n        Xall = dict()\n        Xall['img_s'] = [Xs['img'] for Xs in Xs_list]\n        Xall['target_s'] = [Xs['target'] for Xs in Xs_list]\n        Xall['target_weight_s'] = [Xs['target_weight'] for Xs in Xs_list]\n        xs_img_metas = [Xs['img_metas'].data for Xs in Xs_list]\n\n        Xall['img_q'] = Xq['img']\n        Xall['target_q'] = Xq['target']\n        Xall['target_weight_q'] = Xq['target_weight']\n        xq_img_metas = Xq['img_metas'].data\n\n        img_metas = dict()\n        for key in xq_img_metas.keys():\n            img_metas['sample_' + key] = [xs_img_meta[key] for xs_img_meta in xs_img_metas]\n            img_metas['query_' + key] = xq_img_metas[key]\n        img_metas['bbox_id'] = idx\n\n        Xall['img_metas'] = DC(img_metas, cpu_only=True)\n\n        return Xall\n\n    def __len__(self):\n        \"\"\"Get the size of the dataset.\"\"\"\n        return len(self.paired_samples)\n\n    def __getitem__(self, idx):\n        \"\"\"Get the sample given index.\"\"\"\n\n        pair_ids = self.paired_samples[idx]  # [supported id * shots, query id]\n        assert len(pair_ids) == self.num_shots + 1\n        sample_id_list = pair_ids[:self.num_shots]\n        query_id = pair_ids[-1]\n\n        sample_obj_list = []\n        for sample_id in sample_id_list:\n            sample_obj = copy.deepcopy(self.db[sample_id])\n            sample_obj['ann_info'] = copy.deepcopy(self.ann_info)\n            sample_obj_list.append(sample_obj)\n\n        query_obj = copy.deepcopy(self.db[query_id])\n        query_obj['ann_info'] = copy.deepcopy(self.ann_info)\n\n        Xs_list = []\n        for sample_obj in sample_obj_list:\n            Xs = self.pipeline(sample_obj)  # dict with ['img', 'target', 'target_weight', 'img_metas'],\n            Xs_list.append(Xs)  # Xs['target'] is of shape [100, map_h, map_w]\n        Xq = self.pipeline(query_obj)\n\n        Xall = self._merge_obj(Xs_list, Xq, idx)\n        Xall['skeleton'] = self.db[query_id]['skeleton']\n        return Xall\n\n    def _sort_and_unique_bboxes(self, kpts, key='bbox_id'):\n        \"\"\"sort kpts and remove the repeated ones.\"\"\"\n        
kpts = sorted(kpts, key=lambda x: x[key])\n        num = len(kpts)\n        for i in range(num - 1, 0, -1):\n            if kpts[i][key] == kpts[i - 1][key]:\n                del kpts[i]\n\n        return kpts\n"
  },
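  {
    "path": "examples/sketch_merge_obj_metas.py",
    "content": "\"\"\"Pure-Python sketch of the key scheme in _merge_obj: support metas get a\n'sample_' prefix (listed across shots) and query metas a 'query_' prefix.\nDataContainer is omitted; plain dicts stand in for pipeline outputs.\"\"\"\n\nxs_img_metas = [dict(image_file='s1.jpg', center=(50, 50)),\n                dict(image_file='s2.jpg', center=(60, 40))]  # 2-shot support\nxq_img_metas = dict(image_file='q.jpg', center=(55, 45))\n\nimg_metas = {}\nfor key in xq_img_metas:\n    img_metas['sample_' + key] = [m[key] for m in xs_img_metas]\n    img_metas['query_' + key] = xq_img_metas[key]\nimg_metas['bbox_id'] = 0  # index into paired_samples\n\nprint(sorted(img_metas))\n# ['bbox_id', 'query_center', 'query_image_file',\n#  'sample_center', 'sample_image_file']\n"
  },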
  {
    "path": "models/datasets/datasets/mp100/transformer_dataset.py",
    "content": "import os\nimport random\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom mmpose.datasets import DATASETS\nfrom xtcocotools.coco import COCO\n\nfrom .transformer_base_dataset import TransformerBaseDataset\n\n\n@DATASETS.register_module()\nclass TransformerPoseDataset(TransformerBaseDataset):\n\n    def __init__(self,\n                 ann_file,\n                 img_prefix,\n                 data_cfg,\n                 pipeline,\n                 valid_class_ids,\n                 max_kpt_num=None,\n                 num_shots=1,\n                 num_queries=100,\n                 num_episodes=1,\n                 test_mode=False):\n        super().__init__(\n            ann_file, img_prefix, data_cfg, pipeline, test_mode=test_mode)\n\n        self.ann_info['flip_pairs'] = []\n\n        self.ann_info['upper_body_ids'] = []\n        self.ann_info['lower_body_ids'] = []\n\n        self.ann_info['use_different_joint_weights'] = False\n        self.ann_info['joint_weights'] = np.array([1., ],\n                                                  dtype=np.float32).reshape((self.ann_info['num_joints'], 1))\n\n        self.coco = COCO(ann_file)\n\n        self.id2name, self.name2id = self._get_mapping_id_name(self.coco.imgs)\n        self.img_ids = self.coco.getImgIds()\n        self.classes = [\n            cat['name'] for cat in self.coco.loadCats(self.coco.getCatIds())\n        ]\n\n        self.num_classes = len(self.classes)\n        self._class_to_ind = dict(zip(self.classes, self.coco.getCatIds()))\n        self._ind_to_class = dict(zip(self.coco.getCatIds(), self.classes))\n\n        if valid_class_ids is not None:  # None by default\n            self.valid_class_ids = valid_class_ids\n        else:\n            self.valid_class_ids = self.coco.getCatIds()\n        self.valid_classes = [self._ind_to_class[ind] for ind in self.valid_class_ids]\n\n        self.cats = self.coco.cats\n        self.max_kpt_num = max_kpt_num\n\n        # Also update self.cat2obj\n        self.db = self._get_db()\n\n        self.num_shots = num_shots\n\n        if not test_mode:\n            # Update every training epoch\n            self.random_paired_samples()\n        else:\n            self.num_queries = num_queries\n            self.num_episodes = num_episodes\n            self.make_paired_samples()\n\n    def random_paired_samples(self):\n        num_datas = [len(self.cat2obj[self._class_to_ind[cls]]) for cls in self.valid_classes]\n\n        # balance the dataset\n        max_num_data = max(num_datas)\n\n        all_samples = []\n        for cls in self.valid_class_ids:\n            for i in range(max_num_data):\n                shot = random.sample(self.cat2obj[cls], self.num_shots + 1)\n                all_samples.append(shot)\n\n        self.paired_samples = np.array(all_samples)\n        np.random.shuffle(self.paired_samples)\n\n    def make_paired_samples(self):\n        random.seed(1)\n        np.random.seed(0)\n\n        all_samples = []\n        for cls in self.valid_class_ids:\n            for _ in range(self.num_episodes):\n                shots = random.sample(self.cat2obj[cls], self.num_shots + self.num_queries)\n                sample_ids = shots[:self.num_shots]\n                query_ids = shots[self.num_shots:]\n                for query_id in query_ids:\n                    all_samples.append(sample_ids + [query_id])\n\n        self.paired_samples = np.array(all_samples)\n\n    def _select_kpt(self, obj, kpt_id):\n        obj['joints_3d'] = 
obj['joints_3d'][kpt_id:kpt_id + 1]\n        obj['joints_3d_visible'] = obj['joints_3d_visible'][kpt_id:kpt_id + 1]\n        obj['kpt_id'] = kpt_id\n\n        return obj\n\n    @staticmethod\n    def _get_mapping_id_name(imgs):\n        \"\"\"\n        Args:\n            imgs (dict): dict of image info.\n\n        Returns:\n            tuple: Image name & id mapping dicts.\n\n            - id2name (dict): Mapping image id to name.\n            - name2id (dict): Mapping image name to id.\n        \"\"\"\n        id2name = {}\n        name2id = {}\n        for image_id, image in imgs.items():\n            file_name = image['file_name']\n            id2name[image_id] = file_name\n            name2id[file_name] = image_id\n\n        return id2name, name2id\n\n    def _get_db(self):\n        \"\"\"Ground truth bbox and keypoints.\"\"\"\n        self.obj_id = 0\n\n        self.cat2obj = {}\n        for i in self.coco.getCatIds():\n            self.cat2obj.update({i: []})\n\n        gt_db = []\n        for img_id in self.img_ids:\n            gt_db.extend(self._load_coco_keypoint_annotation_kernel(img_id))\n\n        return gt_db\n\n    def _load_coco_keypoint_annotation_kernel(self, img_id):\n        \"\"\"load annotation from COCOAPI.\n\n        Note:\n            bbox:[x1, y1, w, h]\n        Args:\n            img_id: coco image id\n        Returns:\n            dict: db entry\n        \"\"\"\n        img_ann = self.coco.loadImgs(img_id)[0]\n        width = img_ann['width']\n        height = img_ann['height']\n\n        ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=False)\n        objs = self.coco.loadAnns(ann_ids)\n\n        # sanitize bboxes\n        valid_objs = []\n        for obj in objs:\n            if 'bbox' not in obj:\n                continue\n            x, y, w, h = obj['bbox']\n            x1 = max(0, x)\n            y1 = max(0, y)\n            x2 = min(width - 1, x1 + max(0, w - 1))\n            y2 = min(height - 1, y1 + max(0, h - 1))\n            if ('area' not in obj or obj['area'] > 0) and x2 > x1 and y2 > y1:\n                obj['clean_bbox'] = [x1, y1, x2 - x1, y2 - y1]\n                valid_objs.append(obj)\n        objs = valid_objs\n\n        bbox_id = 0\n        rec = []\n        for obj in objs:\n            if 'keypoints' not in obj:\n                continue\n            if max(obj['keypoints']) == 0:\n                continue\n            if 'num_keypoints' in obj and obj['num_keypoints'] == 0:\n                continue\n\n            category_id = obj['category_id']\n            # the number of keypoint for this specific category\n            cat_kpt_num = int(len(obj['keypoints']) / 3)\n            if self.max_kpt_num is None:\n                kpt_num = cat_kpt_num\n            else:\n                kpt_num = self.max_kpt_num\n\n            joints_3d = np.zeros((kpt_num, 3), dtype=np.float32)\n            joints_3d_visible = np.zeros((kpt_num, 3), dtype=np.float32)\n\n            keypoints = np.array(obj['keypoints']).reshape(-1, 3)\n            joints_3d[:cat_kpt_num, :2] = keypoints[:, :2]\n            joints_3d_visible[:cat_kpt_num, :2] = np.minimum(1, keypoints[:, 2:3])\n\n            center, scale = self._xywh2cs(*obj['clean_bbox'][:4])\n\n            image_file = os.path.join(self.img_prefix, self.id2name[img_id])\n            if os.path.exists(image_file):\n                self.cat2obj[category_id].append(self.obj_id)\n\n                rec.append({\n                    'image_file': image_file,\n                    'center': center,\n            
        'scale': scale,\n                    'rotation': 0,\n                    'bbox': obj['clean_bbox'][:4],\n                    'bbox_score': 1,\n                    'joints_3d': joints_3d,\n                    'joints_3d_visible': joints_3d_visible,\n                    'category_id': category_id,\n                    'cat_kpt_num': cat_kpt_num,\n                    'bbox_id': self.obj_id,\n                    'skeleton': self.coco.cats[obj['category_id']]['skeleton'],\n                })\n                bbox_id = bbox_id + 1\n                self.obj_id += 1\n\n        return rec\n\n    def _xywh2cs(self, x, y, w, h):\n        \"\"\"This encodes a bbox (x, y, w, h) into (center, scale).\n\n        Args:\n            x, y, w, h: Bbox position and size.\n\n        Returns:\n            tuple: A tuple containing center and scale.\n\n            - center (np.ndarray[float32](2,)): center of the bbox (x, y).\n            - scale (np.ndarray[float32](2,)): scale of the bbox w & h.\n        \"\"\"\n        aspect_ratio = self.ann_info['image_size'][0] / self.ann_info['image_size'][1]\n        center = np.array([x + w * 0.5, y + h * 0.5], dtype=np.float32)\n        #\n        # if (not self.test_mode) and np.random.rand() < 0.3:\n        #     center += 0.4 * (np.random.rand(2) - 0.5) * [w, h]\n\n        if w > aspect_ratio * h:\n            h = w * 1.0 / aspect_ratio\n        elif w < aspect_ratio * h:\n            w = h * aspect_ratio\n\n        # pixel std is 200.0\n        scale = np.array([w / 200.0, h / 200.0], dtype=np.float32)\n        # padding to include proper amount of context\n        scale = scale * 1.25\n\n        return center, scale\n\n    def evaluate(self, outputs, res_folder, metric='PCK', **kwargs):\n        \"\"\"Evaluate MP-100 keypoint results. The pose prediction results\n        will be saved in `${res_folder}/result_keypoints.json`.\n\n        Note:\n            batch_size: N\n            num_keypoints: K\n            heatmap height: H\n            heatmap width: W\n\n        Args:\n            outputs (list(preds, boxes, image_path, output_heatmap))\n                :preds (np.ndarray[N,K,3]): The first two dimensions are\n                    coordinates, score is the third dimension of the array.\n                :boxes (np.ndarray[N,6]): [center[0], center[1], scale[0],\n                    scale[1], area, score]\n                :image_paths (list[str]): Paths of the query images.\n                :output_heatmap (np.ndarray[N, K, H, W]): model outputs.\n\n            res_folder (str): Path of directory to save the results.\n            metric (str | list[str]): Metric to be performed.\n                Options: 'PCK', 'AUC', 'EPE', 'NME'.\n\n        Returns:\n            dict: Evaluation results for evaluation metric.\n        \"\"\"\n        metrics = metric if isinstance(metric, list) else [metric]\n        allowed_metrics = ['PCK', 'AUC', 'EPE', 'NME']\n        for metric in metrics:\n            if metric not in allowed_metrics:\n                raise KeyError(f'metric {metric} is not supported')\n\n        res_file = os.path.join(res_folder, 'result_keypoints.json')\n\n        kpts = []\n        for output in outputs:\n            preds = 
output['preds']\n            boxes = output['boxes']\n            image_paths = output['image_paths']\n            bbox_ids = output['bbox_ids']\n\n            batch_size = len(image_paths)\n            for i in range(batch_size):\n                image_id = self.name2id[image_paths[i][len(self.img_prefix):]]\n\n                kpts.append({\n                    'keypoints': preds[i].tolist(),\n                    'center': boxes[i][0:2].tolist(),\n                    'scale': boxes[i][2:4].tolist(),\n                    'area': float(boxes[i][4]),\n                    'score': float(boxes[i][5]),\n                    'image_id': image_id,\n                    'bbox_id': bbox_ids[i]\n                })\n        kpts = self._sort_and_unique_bboxes(kpts)\n\n        self._write_keypoint_results(kpts, res_file)\n        info_str = self._report_metric(res_file, metrics)\n        name_value = OrderedDict(info_str)\n\n        return name_value\n"
  },
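  {
    "path": "examples/sketch_episode_sampling.py",
    "content": "# Hypothetical sketch (editor addition, not part of the original codebase):\n# illustrates how TransformerPoseDataset.make_paired_samples builds its\n# evaluation episodes. 'cat2obj' below is toy data standing in for the\n# class -> object-id mapping that _get_db() builds from the MP-100 annotations.\nimport random\n\nrandom.seed(1)\n\nnum_shots, num_queries = 1, 3\ncat2obj = {0: list(range(10)), 1: list(range(10, 20))}\n\npaired_samples = []\nfor cls, objs in cat2obj.items():\n    # draw support + query ids per class without replacement, as the dataset does\n    shots = random.sample(objs, num_shots + num_queries)\n    support_ids, query_ids = shots[:num_shots], shots[num_shots:]\n    # one episode row per query: [*support ids, query id]\n    for query_id in query_ids:\n        paired_samples.append(support_ids + [query_id])\n\n# each row holds num_shots support object ids followed by one query object id\nprint(paired_samples)\n"
  },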
  {
    "path": "models/datasets/pipelines/__init__.py",
    "content": "from .top_down_transform import (TopDownAffineFewShot,\n                                 TopDownGenerateTargetFewShot)\n\n__all__ = [\n    'TopDownGenerateTargetFewShot', 'TopDownAffineFewShot'\n]\n"
  },
  {
    "path": "models/datasets/pipelines/post_transforms.py",
    "content": "# ------------------------------------------------------------------------------\n# Adapted from https://github.com/leoxiaobin/deep-high-resolution-net.pytorch\n# Original licence: Copyright (c) Microsoft, under the MIT License.\n# ------------------------------------------------------------------------------\n\nimport cv2\nimport numpy as np\n\n\ndef get_affine_transform(center,\n                         scale,\n                         rot,\n                         output_size,\n                         shift=(0., 0.),\n                         inv=False):\n    \"\"\"Get the affine transform matrix, given the center/scale/rot/output_size.\n\n    Args:\n        center (np.ndarray[2, ]): Center of the bounding box (x, y).\n        scale (np.ndarray[2, ]): Scale of the bounding box\n            wrt [width, height].\n        rot (float): Rotation angle (degree).\n        output_size (np.ndarray[2, ]): Size of the destination heatmaps.\n        shift (0-100%): Shift translation ratio wrt the width/height.\n            Default (0., 0.).\n        inv (bool): Option to inverse the affine transform direction.\n            (inv=False: src->dst or inv=True: dst->src)\n\n    Returns:\n        np.ndarray: The transform matrix.\n    \"\"\"\n    assert len(center) == 2\n    assert len(scale) == 2\n    assert len(output_size) == 2\n    assert len(shift) == 2\n\n    # pixel_std is 200.\n    scale_tmp = scale * 200.0\n\n    shift = np.array(shift)\n    src_w = scale_tmp[0]\n    dst_w = output_size[0]\n    dst_h = output_size[1]\n\n    rot_rad = np.pi * rot / 180\n    src_dir = rotate_point([0., src_w * -0.5], rot_rad)\n    dst_dir = np.array([0., dst_w * -0.5])\n\n    src = np.zeros((3, 2), dtype=np.float32)\n    src[0, :] = center + scale_tmp * shift\n    src[1, :] = center + src_dir + scale_tmp * shift\n    src[2, :] = _get_3rd_point(src[0, :], src[1, :])\n\n    dst = np.zeros((3, 2), dtype=np.float32)\n    dst[0, :] = [dst_w * 0.5, dst_h * 0.5]\n    dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir\n    dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])\n\n    if inv:\n        trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))\n    else:\n        trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))\n\n    return trans\n\n\ndef affine_transform(pt, trans_mat):\n    \"\"\"Apply an affine transformation to the points.\n\n    Args:\n        pt (np.ndarray): a 2 dimensional point to be transformed\n        trans_mat (np.ndarray): 2x3 matrix of an affine transform\n\n    Returns:\n        np.ndarray: Transformed points.\n    \"\"\"\n    assert len(pt) == 2\n    new_pt = np.array(trans_mat) @ np.array([pt[0], pt[1], 1.])\n\n    return new_pt\n\n\ndef _get_3rd_point(a, b):\n    \"\"\"To calculate the affine matrix, three pairs of points are required. 
This\n    function is used to get the 3rd point, given 2D points a & b.\n\n    The 3rd point is defined by rotating vector `a - b` by 90 degrees\n    anticlockwise, using b as the rotation center.\n\n    Args:\n        a (np.ndarray): point(x,y)\n        b (np.ndarray): point(x,y)\n\n    Returns:\n        np.ndarray: The 3rd point.\n    \"\"\"\n    assert len(a) == 2\n    assert len(b) == 2\n    direction = a - b\n    third_pt = b + np.array([-direction[1], direction[0]], dtype=np.float32)\n\n    return third_pt\n\n\ndef rotate_point(pt, angle_rad):\n    \"\"\"Rotate a point by an angle.\n\n    Args:\n        pt (list[float]): 2 dimensional point to be rotated\n        angle_rad (float): rotation angle by radian\n\n    Returns:\n        list[float]: Rotated point.\n    \"\"\"\n    assert len(pt) == 2\n    sn, cs = np.sin(angle_rad), np.cos(angle_rad)\n    new_x = pt[0] * cs - pt[1] * sn\n    new_y = pt[0] * sn + pt[1] * cs\n    rotated_pt = [new_x, new_y]\n\n    return rotated_pt\n"
  },
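  {
    "path": "examples/sketch_affine_transform.py",
    "content": "# Hypothetical sketch (editor addition, not part of the original codebase):\n# shows how get_affine_transform and affine_transform from\n# models/datasets/pipelines/post_transforms.py fit together. Assumes it is\n# run from the repo root with numpy and cv2 installed.\nimport numpy as np\n\nfrom models.datasets.pipelines.post_transforms import (affine_transform,\n                                                       get_affine_transform)\n\ncenter = np.array([120., 80.], dtype=np.float32)\nscale = np.array([1.0, 1.0], dtype=np.float32)  # in units of 200 px (pixel_std)\noutput_size = np.array([192, 256], dtype=np.float32)\n\ntrans = get_affine_transform(center, scale, 0., output_size)\ninv_trans = get_affine_transform(center, scale, 0., output_size, inv=True)\n\n# the bbox center lands on the center of the destination map ...\npt = affine_transform(center, trans)\nprint(pt)  # ~[96., 128.]\n\n# ... and the inverse matrix maps it back\nprint(affine_transform(pt, inv_trans))  # ~[120., 80.]\n"
  },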
  {
    "path": "models/datasets/pipelines/top_down_transform.py",
    "content": "import cv2\nimport numpy as np\nfrom mmpose.core.post_processing import (get_warp_matrix,\n                                         warp_affine_joints)\nfrom mmpose.datasets.builder import PIPELINES\n\nfrom .post_transforms import (affine_transform,\n                              get_affine_transform)\n\n\n@PIPELINES.register_module()\nclass TopDownAffineFewShot:\n    \"\"\"Affine transform the image to make input.\n\n    Required keys:'img', 'joints_3d', 'joints_3d_visible', 'ann_info','scale',\n    'rotation' and 'center'. Modified keys:'img', 'joints_3d', and\n    'joints_3d_visible'.\n\n    Args:\n        use_udp (bool): To use unbiased data processing.\n            Paper ref: Huang et al. The Devil is in the Details: Delving into\n            Unbiased Data Processing for Human Pose Estimation (CVPR 2020).\n    \"\"\"\n\n    def __init__(self, use_udp=False):\n        self.use_udp = use_udp\n\n    def __call__(self, results):\n        image_size = results['ann_info']['image_size']\n\n        img = results['img']\n        joints_3d = results['joints_3d']\n        joints_3d_visible = results['joints_3d_visible']\n        c = results['center']\n        s = results['scale']\n        r = results['rotation']\n\n        if self.use_udp:\n            trans = get_warp_matrix(r, c * 2.0, image_size - 1.0, s * 200.0)\n            img = cv2.warpAffine(\n                img,\n                trans, (int(image_size[0]), int(image_size[1])),\n                flags=cv2.INTER_LINEAR)\n            joints_3d[:, 0:2] = \\\n                warp_affine_joints(joints_3d[:, 0:2].copy(), trans)\n        else:\n            trans = get_affine_transform(c, s, r, image_size)\n            img = cv2.warpAffine(\n                img,\n                trans, (int(image_size[0]), int(image_size[1])),\n                flags=cv2.INTER_LINEAR)\n            for i in range(len(joints_3d)):\n                if joints_3d_visible[i, 0] > 0.0:\n                    joints_3d[i,\n                    0:2] = affine_transform(joints_3d[i, 0:2], trans)\n\n        results['img'] = img\n        results['joints_3d'] = joints_3d\n        results['joints_3d_visible'] = joints_3d_visible\n\n        return results\n\n\n@PIPELINES.register_module()\nclass TopDownGenerateTargetFewShot:\n    \"\"\"Generate the target heatmap.\n\n    Required keys: 'joints_3d', 'joints_3d_visible', 'ann_info'.\n    Modified keys: 'target', and 'target_weight'.\n\n    Args:\n        sigma: Sigma of heatmap gaussian for 'MSRA' approach.\n        kernel: Kernel of heatmap gaussian for 'Megvii' approach.\n        encoding (str): Approach to generate target heatmaps.\n            Currently supported approaches: 'MSRA', 'Megvii', 'UDP'.\n            Default:'MSRA'\n\n        unbiased_encoding (bool): Option to use unbiased\n            encoding methods.\n            Paper ref: Zhang et al. Distribution-Aware Coordinate\n            Representation for Human Pose Estimation (CVPR 2020).\n        keypoint_pose_distance: Keypoint pose distance for UDP.\n            Paper ref: Huang et al. The Devil is in the Details: Delving into\n            Unbiased Data Processing for Human Pose Estimation (CVPR 2020).\n        target_type (str): supported targets: 'GaussianHeatMap',\n            'CombinedTarget'. Default:'GaussianHeatMap'\n            CombinedTarget: The combination of classification target\n            (response map) and regression target (offset map).\n            Paper ref: Huang et al. 
The Devil is in the Details: Delving into\n            Unbiased Data Processing for Human Pose Estimation (CVPR 2020).\n    \"\"\"\n\n    def __init__(self,\n                 sigma=2,\n                 kernel=(11, 11),\n                 valid_radius_factor=0.0546875,\n                 target_type='GaussianHeatMap',\n                 encoding='MSRA',\n                 unbiased_encoding=False):\n        self.sigma = sigma\n        self.unbiased_encoding = unbiased_encoding\n        self.kernel = kernel\n        self.valid_radius_factor = valid_radius_factor\n        self.target_type = target_type\n        self.encoding = encoding\n\n    def _msra_generate_target(self, cfg, joints_3d, joints_3d_visible, sigma):\n        \"\"\"Generate the target heatmap via \"MSRA\" approach.\n\n        Args:\n            cfg (dict): data config\n            joints_3d: np.ndarray ([num_joints, 3])\n            joints_3d_visible: np.ndarray ([num_joints, 3])\n            sigma: Sigma of heatmap gaussian\n        Returns:\n            tuple: A tuple containing targets.\n\n            - target: Target heatmaps.\n            - target_weight: (1: visible, 0: invisible)\n        \"\"\"\n        num_joints = len(joints_3d)\n        image_size = cfg['image_size']\n        W, H = cfg['heatmap_size']\n        joint_weights = cfg['joint_weights']\n        use_different_joint_weights = cfg['use_different_joint_weights']\n        assert not use_different_joint_weights\n\n        target_weight = np.zeros((num_joints, 1), dtype=np.float32)\n        target = np.zeros((num_joints, H, W), dtype=np.float32)\n\n        # 3-sigma rule\n        tmp_size = sigma * 3\n\n        if self.unbiased_encoding:\n            for joint_id in range(num_joints):\n                target_weight[joint_id] = joints_3d_visible[joint_id, 0]\n\n                feat_stride = image_size / [W, H]\n                mu_x = joints_3d[joint_id][0] / feat_stride[0]\n                mu_y = joints_3d[joint_id][1] / feat_stride[1]\n                # Check that any part of the gaussian is in-bounds\n                ul = [mu_x - tmp_size, mu_y - tmp_size]\n                br = [mu_x + tmp_size + 1, mu_y + tmp_size + 1]\n                if ul[0] >= W or ul[1] >= H or br[0] < 0 or br[1] < 0:\n                    target_weight[joint_id] = 0\n\n                if target_weight[joint_id] == 0:\n                    continue\n\n                x = np.arange(0, W, 1, np.float32)\n                y = np.arange(0, H, 1, np.float32)\n                y = y[:, None]\n\n                if target_weight[joint_id] > 0.5:\n                    target[joint_id] = np.exp(-((x - mu_x) ** 2 +\n                                                (y - mu_y) ** 2) /\n                                              (2 * sigma ** 2))\n        else:\n            for joint_id in range(num_joints):\n                target_weight[joint_id] = joints_3d_visible[joint_id, 0]\n\n                feat_stride = image_size / [W, H]\n                mu_x = int(joints_3d[joint_id][0] / feat_stride[0] + 0.5)\n                mu_y = int(joints_3d[joint_id][1] / feat_stride[1] + 0.5)\n                # Check that any part of the gaussian is in-bounds\n                ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]\n                br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]\n                if ul[0] >= W or ul[1] >= H or br[0] < 0 or br[1] < 0:\n                    target_weight[joint_id] = 0\n\n                if target_weight[joint_id] > 0.5:\n                    size = 2 * tmp_size + 1\n        
            x = np.arange(0, size, 1, np.float32)\n                    y = x[:, None]\n                    x0 = y0 = size // 2\n                    # The gaussian is not normalized,\n                    # we want the center value to equal 1\n                    g = np.exp(-((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))\n\n                    # Usable gaussian range\n                    g_x = max(0, -ul[0]), min(br[0], W) - ul[0]\n                    g_y = max(0, -ul[1]), min(br[1], H) - ul[1]\n                    # Image range\n                    img_x = max(0, ul[0]), min(br[0], W)\n                    img_y = max(0, ul[1]), min(br[1], H)\n\n                    target[joint_id][img_y[0]:img_y[1], img_x[0]:img_x[1]] = \\\n                        g[g_y[0]:g_y[1], g_x[0]:g_x[1]]\n\n        if use_different_joint_weights:\n            target_weight = np.multiply(target_weight, joint_weights)\n\n        return target, target_weight\n\n    def _udp_generate_target(self, cfg, joints_3d, joints_3d_visible, factor,\n                             target_type):\n        \"\"\"Generate the target heatmap via 'UDP' approach. Paper ref: Huang et\n        al. The Devil is in the Details: Delving into Unbiased Data Processing\n        for Human Pose Estimation (CVPR 2020).\n\n        Note:\n            num keypoints: K\n            heatmap height: H\n            heatmap width: W\n            num target channels: C\n            C = K if target_type=='GaussianHeatMap'\n            C = 3*K if target_type=='CombinedTarget'\n\n        Args:\n            cfg (dict): data config\n            joints_3d (np.ndarray[K, 3]): Annotated keypoints.\n            joints_3d_visible (np.ndarray[K, 3]): Visibility of keypoints.\n            factor (float): kernel factor for GaussianHeatMap target or\n                valid radius factor for CombinedTarget.\n            target_type (str): 'GaussianHeatMap' or 'CombinedTarget'.\n                GaussianHeatMap: Heatmap target with gaussian distribution.\n                CombinedTarget: The combination of classification target\n                (response map) and regression target (offset map).\n\n        Returns:\n            tuple: A tuple containing targets.\n\n            - target (np.ndarray[C, H, W]): Target heatmaps.\n            - target_weight (np.ndarray[K, 1]): (1: visible, 0: invisible)\n        \"\"\"\n        num_joints = len(joints_3d)\n        image_size = cfg['image_size']\n        heatmap_size = cfg['heatmap_size']\n        joint_weights = cfg['joint_weights']\n        use_different_joint_weights = cfg['use_different_joint_weights']\n        assert not use_different_joint_weights\n\n        target_weight = np.ones((num_joints, 1), dtype=np.float32)\n        target_weight[:, 0] = joints_3d_visible[:, 0]\n\n        assert target_type in ['GaussianHeatMap', 'CombinedTarget']\n\n        if target_type == 'GaussianHeatMap':\n            target = np.zeros((num_joints, heatmap_size[1], heatmap_size[0]),\n                              dtype=np.float32)\n\n            tmp_size = factor * 3\n\n            # prepare for gaussian\n            size = 2 * tmp_size + 1\n            x = np.arange(0, size, 1, np.float32)\n            y = x[:, None]\n\n            for joint_id in range(num_joints):\n                feat_stride = (image_size - 1.0) / (heatmap_size - 1.0)\n                mu_x = int(joints_3d[joint_id][0] / feat_stride[0] + 0.5)\n                mu_y = int(joints_3d[joint_id][1] / feat_stride[1] + 0.5)\n                # Check that any part of the gaussian 
is in-bounds\n                ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]\n                br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]\n                if ul[0] >= heatmap_size[0] or ul[1] >= heatmap_size[1] \\\n                        or br[0] < 0 or br[1] < 0:\n                    # If not, just return the image as is\n                    target_weight[joint_id] = 0\n                    continue\n\n                # # Generate gaussian\n                mu_x_ac = joints_3d[joint_id][0] / feat_stride[0]\n                mu_y_ac = joints_3d[joint_id][1] / feat_stride[1]\n                x0 = y0 = size // 2\n                x0 += mu_x_ac - mu_x\n                y0 += mu_y_ac - mu_y\n                g = np.exp(-((x - x0) ** 2 + (y - y0) ** 2) / (2 * factor ** 2))\n\n                # Usable gaussian range\n                g_x = max(0, -ul[0]), min(br[0], heatmap_size[0]) - ul[0]\n                g_y = max(0, -ul[1]), min(br[1], heatmap_size[1]) - ul[1]\n                # Image range\n                img_x = max(0, ul[0]), min(br[0], heatmap_size[0])\n                img_y = max(0, ul[1]), min(br[1], heatmap_size[1])\n\n                v = target_weight[joint_id]\n                if v > 0.5:\n                    target[joint_id][img_y[0]:img_y[1], img_x[0]:img_x[1]] = \\\n                        g[g_y[0]:g_y[1], g_x[0]:g_x[1]]\n        elif target_type == 'CombinedTarget':\n            target = np.zeros(\n                (num_joints, 3, heatmap_size[1] * heatmap_size[0]),\n                dtype=np.float32)\n            feat_width = heatmap_size[0]\n            feat_height = heatmap_size[1]\n            feat_x_int = np.arange(0, feat_width)\n            feat_y_int = np.arange(0, feat_height)\n            feat_x_int, feat_y_int = np.meshgrid(feat_x_int, feat_y_int)\n            feat_x_int = feat_x_int.flatten()\n            feat_y_int = feat_y_int.flatten()\n            # Calculate the radius of the positive area in classification\n            #   heatmap.\n            valid_radius = factor * heatmap_size[1]\n            feat_stride = (image_size - 1.0) / (heatmap_size - 1.0)\n            for joint_id in range(num_joints):\n                mu_x = joints_3d[joint_id][0] / feat_stride[0]\n                mu_y = joints_3d[joint_id][1] / feat_stride[1]\n                x_offset = (mu_x - feat_x_int) / valid_radius\n                y_offset = (mu_y - feat_y_int) / valid_radius\n                dis = x_offset ** 2 + y_offset ** 2\n                keep_pos = np.where(dis <= 1)[0]\n                v = target_weight[joint_id]\n                if v > 0.5:\n                    target[joint_id, 0, keep_pos] = 1\n                    target[joint_id, 1, keep_pos] = x_offset[keep_pos]\n                    target[joint_id, 2, keep_pos] = y_offset[keep_pos]\n            target = target.reshape(num_joints * 3, heatmap_size[1],\n                                    heatmap_size[0])\n\n        if use_different_joint_weights:\n            target_weight = np.multiply(target_weight, joint_weights)\n\n        return target, target_weight\n\n    def __call__(self, results):\n        \"\"\"Generate the target heatmap.\"\"\"\n        joints_3d = results['joints_3d']\n        joints_3d_visible = results['joints_3d_visible']\n\n        assert self.encoding in ['MSRA', 'UDP']\n\n        if self.encoding == 'MSRA':\n            if isinstance(self.sigma, list):\n                num_sigmas = len(self.sigma)\n                cfg = results['ann_info']\n                num_joints = len(joints_3d)\n               
 heatmap_size = cfg['heatmap_size']\n\n                target = np.empty(\n                    (0, num_joints, heatmap_size[1], heatmap_size[0]),\n                    dtype=np.float32)\n                target_weight = np.empty((0, num_joints, 1), dtype=np.float32)\n                for i in range(num_sigmas):\n                    target_i, target_weight_i = self._msra_generate_target(\n                        cfg, joints_3d, joints_3d_visible, self.sigma[i])\n                    target = np.concatenate([target, target_i[None]], axis=0)\n                    target_weight = np.concatenate(\n                        [target_weight, target_weight_i[None]], axis=0)\n            else:\n                target, target_weight = self._msra_generate_target(\n                    results['ann_info'], joints_3d, joints_3d_visible,\n                    self.sigma)\n        elif self.encoding == 'UDP':\n            if self.target_type == 'CombinedTarget':\n                factors = self.valid_radius_factor\n                channel_factor = 3\n            elif self.target_type == 'GaussianHeatMap':\n                factors = self.sigma\n                channel_factor = 1\n            if isinstance(factors, list):\n                num_factors = len(factors)\n                cfg = results['ann_info']\n                num_joints = len(joints_3d)\n                W, H = cfg['heatmap_size']\n\n                target = np.empty((0, channel_factor * num_joints, H, W),\n                                  dtype=np.float32)\n                target_weight = np.empty((0, num_joints, 1), dtype=np.float32)\n                for i in range(num_factors):\n                    target_i, target_weight_i = self._udp_generate_target(\n                        cfg, joints_3d, joints_3d_visible, factors[i],\n                        self.target_type)\n                    target = np.concatenate([target, target_i[None]], axis=0)\n                    target_weight = np.concatenate(\n                        [target_weight, target_weight_i[None]], axis=0)\n            else:\n                target, target_weight = self._udp_generate_target(\n                    results['ann_info'], joints_3d, joints_3d_visible, factors,\n                    self.target_type)\n        else:\n            raise ValueError(\n                f'Encoding approach {self.encoding} is not supported!')\n\n        results['target'] = target\n        results['target_weight'] = target_weight\n\n        return results\n"
  },
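  {
    "path": "examples/sketch_msra_heatmap.py",
    "content": "# Hypothetical sketch (editor addition, not part of the original codebase):\n# standalone NumPy walk-through of the unnormalized-gaussian placement used\n# by TopDownGenerateTargetFewShot._msra_generate_target for one visible\n# joint, including the 3-sigma support window and the boundary clipping.\nimport numpy as np\n\nW, H, sigma = 48, 64, 2\njoint = np.array([100., 60.])  # joint location in input-image pixels\nimage_size = np.array([192., 256.])\n\nfeat_stride = image_size / [W, H]  # input pixels per heatmap cell\nmu_x = int(joint[0] / feat_stride[0] + 0.5)\nmu_y = int(joint[1] / feat_stride[1] + 0.5)\n\ntmp_size = sigma * 3  # 3-sigma rule: gaussian support radius\nsize = 2 * tmp_size + 1\nx = np.arange(0, size, 1, np.float32)\ny = x[:, None]\nx0 = y0 = size // 2\n# unnormalized gaussian: the center value is exactly 1\ng = np.exp(-((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))\n\n# clip the gaussian patch to the heatmap bounds\nul = [mu_x - tmp_size, mu_y - tmp_size]\nbr = [mu_x + tmp_size + 1, mu_y + tmp_size + 1]\ntarget = np.zeros((H, W), np.float32)\ng_x = max(0, -ul[0]), min(br[0], W) - ul[0]\ng_y = max(0, -ul[1]), min(br[1], H) - ul[1]\nimg_x = max(0, ul[0]), min(br[0], W)\nimg_y = max(0, ul[1]), min(br[1], H)\ntarget[img_y[0]:img_y[1], img_x[0]:img_x[1]] = g[g_y[0]:g_y[1], g_x[0]:g_x[1]]\n\nprint(target.max(), (mu_x, mu_y))  # 1.0 at the quantized joint location\n"
  },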
  {
    "path": "models/models/__init__.py",
    "content": "from .backbones import *  # noqa\nfrom .detectors import *  # noqa\nfrom .keypoint_heads import *  # noqa\n"
  },
  {
    "path": "models/models/backbones/__init__.py",
    "content": "from .swin_transformer_v2 import SwinTransformerV2\n"
  },
  {
    "path": "models/models/backbones/simmim.py",
    "content": "# --------------------------------------------------------\n# SimMIM\n# Copyright (c) 2021 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Zhenda Xie\n# --------------------------------------------------------\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom timm.models.layers import trunc_normal_\n\nfrom .swin_transformer import SwinTransformer\nfrom .swin_transformer_v2 import SwinTransformerV2\n\n\ndef norm_targets(targets, patch_size):\n    assert patch_size % 2 == 1\n\n    targets_ = targets\n    targets_count = torch.ones_like(targets)\n\n    targets_square = targets ** 2.\n\n    targets_mean = F.avg_pool2d(targets, kernel_size=patch_size, stride=1, padding=patch_size // 2,\n                                count_include_pad=False)\n    targets_square_mean = F.avg_pool2d(targets_square, kernel_size=patch_size, stride=1, padding=patch_size // 2,\n                                       count_include_pad=False)\n    targets_count = F.avg_pool2d(targets_count, kernel_size=patch_size, stride=1, padding=patch_size // 2,\n                                 count_include_pad=True) * (patch_size ** 2)\n\n    targets_var = (targets_square_mean - targets_mean ** 2.) * (targets_count / (targets_count - 1))\n    targets_var = torch.clamp(targets_var, min=0.)\n\n    targets_ = (targets_ - targets_mean) / (targets_var + 1.e-6) ** 0.5\n\n    return targets_\n\n\nclass SwinTransformerForSimMIM(SwinTransformer):\n    def __init__(self, **kwargs):\n        super().__init__(**kwargs)\n\n        assert self.num_classes == 0\n\n        self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))\n        trunc_normal_(self.mask_token, mean=0., std=.02)\n\n    def forward(self, x, mask):\n        x = self.patch_embed(x)\n\n        assert mask is not None\n        B, L, _ = x.shape\n\n        mask_tokens = self.mask_token.expand(B, L, -1)\n        w = mask.flatten(1).unsqueeze(-1).type_as(mask_tokens)\n        x = x * (1. - w) + mask_tokens * w\n\n        if self.ape:\n            x = x + self.absolute_pos_embed\n        x = self.pos_drop(x)\n\n        for layer in self.layers:\n            x = layer(x)\n        x = self.norm(x)\n\n        x = x.transpose(1, 2)\n        B, C, L = x.shape\n        H = W = int(L ** 0.5)\n        x = x.reshape(B, C, H, W)\n        return x\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return super().no_weight_decay() | {'mask_token'}\n\n\nclass SwinTransformerV2ForSimMIM(SwinTransformerV2):\n    def __init__(self, **kwargs):\n        super().__init__(**kwargs)\n\n        assert self.num_classes == 0\n\n        self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))\n        trunc_normal_(self.mask_token, mean=0., std=.02)\n\n    def forward(self, x, mask):\n        x = self.patch_embed(x)\n\n        assert mask is not None\n        B, L, _ = x.shape\n\n        mask_tokens = self.mask_token.expand(B, L, -1)\n        w = mask.flatten(1).unsqueeze(-1).type_as(mask_tokens)\n        x = x * (1. 
- w) + mask_tokens * w\n\n        if self.ape:\n            x = x + self.absolute_pos_embed\n        x = self.pos_drop(x)\n\n        for layer in self.layers:\n            x = layer(x)\n        x = self.norm(x)\n\n        x = x.transpose(1, 2)\n        B, C, L = x.shape\n        H = W = int(L ** 0.5)\n        x = x.reshape(B, C, H, W)\n        return x\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return super().no_weight_decay() | {'mask_token'}\n\n\nclass SimMIM(nn.Module):\n    def __init__(self, config, encoder, encoder_stride, in_chans, patch_size):\n        super().__init__()\n        self.config = config\n        self.encoder = encoder\n        self.encoder_stride = encoder_stride\n\n        self.decoder = nn.Sequential(\n            nn.Conv2d(\n                in_channels=self.encoder.num_features,\n                out_channels=self.encoder_stride ** 2 * 3, kernel_size=1),\n            nn.PixelShuffle(self.encoder_stride),\n        )\n\n        self.in_chans = in_chans\n        self.patch_size = patch_size\n\n    def forward(self, x, mask):\n        z = self.encoder(x, mask)\n        x_rec = self.decoder(z)\n\n        mask = mask.repeat_interleave(self.patch_size, 1).repeat_interleave(self.patch_size, 2).unsqueeze(\n            1).contiguous()\n\n        # norm target as prompted\n        if self.config.NORM_TARGET.ENABLE:\n            x = norm_targets(x, self.config.NORM_TARGET.PATCH_SIZE)\n\n        loss_recon = F.l1_loss(x, x_rec, reduction='none')\n        loss = (loss_recon * mask).sum() / (mask.sum() + 1e-5) / self.in_chans\n        return loss\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        if hasattr(self.encoder, 'no_weight_decay'):\n            return {'encoder.' + i for i in self.encoder.no_weight_decay()}\n        return {}\n\n    @torch.jit.ignore\n    def no_weight_decay_keywords(self):\n        if hasattr(self.encoder, 'no_weight_decay_keywords'):\n            return {'encoder.' 
+ i for i in self.encoder.no_weight_decay_keywords()}\n        return {}\n\n\ndef build_simmim(config):\n    model_type = config.MODEL.TYPE\n    if model_type == 'swin':\n        encoder = SwinTransformerForSimMIM(\n            img_size=config.DATA.IMG_SIZE,\n            patch_size=config.MODEL.SWIN.PATCH_SIZE,\n            in_chans=config.MODEL.SWIN.IN_CHANS,\n            num_classes=0,\n            embed_dim=config.MODEL.SWIN.EMBED_DIM,\n            depths=config.MODEL.SWIN.DEPTHS,\n            num_heads=config.MODEL.SWIN.NUM_HEADS,\n            window_size=config.MODEL.SWIN.WINDOW_SIZE,\n            mlp_ratio=config.MODEL.SWIN.MLP_RATIO,\n            qkv_bias=config.MODEL.SWIN.QKV_BIAS,\n            qk_scale=config.MODEL.SWIN.QK_SCALE,\n            drop_rate=config.MODEL.DROP_RATE,\n            drop_path_rate=config.MODEL.DROP_PATH_RATE,\n            ape=config.MODEL.SWIN.APE,\n            patch_norm=config.MODEL.SWIN.PATCH_NORM,\n            use_checkpoint=config.TRAIN.USE_CHECKPOINT)\n        encoder_stride = 32\n        in_chans = config.MODEL.SWIN.IN_CHANS\n        patch_size = config.MODEL.SWIN.PATCH_SIZE\n    elif model_type == 'swinv2':\n        encoder = SwinTransformerV2ForSimMIM(\n            img_size=config.DATA.IMG_SIZE,\n            patch_size=config.MODEL.SWINV2.PATCH_SIZE,\n            in_chans=config.MODEL.SWINV2.IN_CHANS,\n            num_classes=0,\n            embed_dim=config.MODEL.SWINV2.EMBED_DIM,\n            depths=config.MODEL.SWINV2.DEPTHS,\n            num_heads=config.MODEL.SWINV2.NUM_HEADS,\n            window_size=config.MODEL.SWINV2.WINDOW_SIZE,\n            mlp_ratio=config.MODEL.SWINV2.MLP_RATIO,\n            qkv_bias=config.MODEL.SWINV2.QKV_BIAS,\n            drop_rate=config.MODEL.DROP_RATE,\n            drop_path_rate=config.MODEL.DROP_PATH_RATE,\n            ape=config.MODEL.SWINV2.APE,\n            patch_norm=config.MODEL.SWINV2.PATCH_NORM,\n            use_checkpoint=config.TRAIN.USE_CHECKPOINT)\n        encoder_stride = 32\n        in_chans = config.MODEL.SWINV2.IN_CHANS\n        patch_size = config.MODEL.SWINV2.PATCH_SIZE\n    else:\n        raise NotImplementedError(f\"Unknown pre-train model: {model_type}\")\n\n    model = SimMIM(config=config.MODEL.SIMMIM, encoder=encoder, encoder_stride=encoder_stride, in_chans=in_chans,\n                   patch_size=patch_size)\n\n    return model\n"
  },
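  {
    "path": "examples/sketch_simmim_loss.py",
    "content": "# Hypothetical sketch (editor addition, not part of the original codebase):\n# reproduces the masked L1 reconstruction loss from SimMIM.forward on random\n# tensors, so the patch-mask upsampling can be inspected without a model.\nimport torch\nimport torch.nn.functional as F\n\nB, C, H, W, patch_size = 2, 3, 32, 32, 4\nx = torch.randn(B, C, H, W)      # target image\nx_rec = torch.randn(B, C, H, W)  # stand-in for the decoder output\n\n# binary patch mask: 1 marks a masked patch whose pixels enter the loss\nmask = (torch.rand(B, H // patch_size, W // patch_size) > 0.5).float()\n# expand each patch flag to pixel resolution, exactly as SimMIM.forward does\nmask = mask.repeat_interleave(patch_size, 1).repeat_interleave(\n    patch_size, 2).unsqueeze(1).contiguous()\n\nloss_recon = F.l1_loss(x, x_rec, reduction='none')\n# average over masked pixels only, then over input channels\nloss = (loss_recon * mask).sum() / (mask.sum() + 1e-5) / C\nprint(loss)\n"
  },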
  {
    "path": "models/models/backbones/swin_mlp.py",
    "content": "# --------------------------------------------------------\n# Swin Transformer\n# Copyright (c) 2021 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Ze Liu\n# --------------------------------------------------------\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.utils.checkpoint as checkpoint\nfrom timm.models.layers import DropPath, to_2tuple, trunc_normal_\n\n\nclass Mlp(nn.Module):\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\ndef window_partition(x, window_size):\n    \"\"\"\n    Args:\n        x: (B, H, W, C)\n        window_size (int): window size\n\n    Returns:\n        windows: (num_windows*B, window_size, window_size, C)\n    \"\"\"\n    B, H, W, C = x.shape\n    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)\n    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)\n    return windows\n\n\ndef window_reverse(windows, window_size, H, W):\n    \"\"\"\n    Args:\n        windows: (num_windows*B, window_size, window_size, C)\n        window_size (int): Window size\n        H (int): Height of image\n        W (int): Width of image\n\n    Returns:\n        x: (B, H, W, C)\n    \"\"\"\n    B = int(windows.shape[0] / (H * W / window_size / window_size))\n    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)\n    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)\n    return x\n\n\nclass SwinMLPBlock(nn.Module):\n    r\"\"\" Swin MLP Block.\n\n    Args:\n        dim (int): Number of input channels.\n        input_resolution (tuple[int]): Input resulotion.\n        num_heads (int): Number of attention heads.\n        window_size (int): Window size.\n        shift_size (int): Shift size for SW-MSA.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        drop (float, optional): Dropout rate. Default: 0.0\n        drop_path (float, optional): Stochastic depth rate. Default: 0.0\n        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU\n        norm_layer (nn.Module, optional): Normalization layer.  
Default: nn.LayerNorm\n    \"\"\"\n\n    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,\n                 mlp_ratio=4., drop=0., drop_path=0.,\n                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):\n        super().__init__()\n        self.dim = dim\n        self.input_resolution = input_resolution\n        self.num_heads = num_heads\n        self.window_size = window_size\n        self.shift_size = shift_size\n        self.mlp_ratio = mlp_ratio\n        if min(self.input_resolution) <= self.window_size:\n            # if window size is larger than input resolution, we don't partition windows\n            self.shift_size = 0\n            self.window_size = min(self.input_resolution)\n        assert 0 <= self.shift_size < self.window_size, \"shift_size must be in the range [0, window_size)\"\n\n        self.padding = [self.window_size - self.shift_size, self.shift_size,\n                        self.window_size - self.shift_size, self.shift_size]  # P_l,P_r,P_t,P_b\n\n        self.norm1 = norm_layer(dim)\n        # use group convolution to implement multi-head MLP\n        self.spatial_mlp = nn.Conv1d(self.num_heads * self.window_size ** 2,\n                                     self.num_heads * self.window_size ** 2,\n                                     kernel_size=1,\n                                     groups=self.num_heads)\n\n        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n\n    def forward(self, x):\n        H, W = self.input_resolution\n        B, L, C = x.shape\n        assert L == H * W, \"input feature has wrong size\"\n\n        shortcut = x\n        x = self.norm1(x)\n        x = x.view(B, H, W, C)\n\n        # shift\n        if self.shift_size > 0:\n            P_l, P_r, P_t, P_b = self.padding\n            shifted_x = F.pad(x, [0, 0, P_l, P_r, P_t, P_b], \"constant\", 0)\n        else:\n            shifted_x = x\n        _, _H, _W, _ = shifted_x.shape\n\n        # partition windows\n        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C\n        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C\n\n        # Window/Shifted-Window Spatial MLP\n        x_windows_heads = x_windows.view(-1, self.window_size * self.window_size, self.num_heads, C // self.num_heads)\n        x_windows_heads = x_windows_heads.transpose(1, 2)  # nW*B, nH, window_size*window_size, C//nH\n        x_windows_heads = x_windows_heads.reshape(-1, self.num_heads * self.window_size * self.window_size,\n                                                  C // self.num_heads)\n        spatial_mlp_windows = self.spatial_mlp(x_windows_heads)  # nW*B, nH*window_size*window_size, C//nH\n        spatial_mlp_windows = spatial_mlp_windows.view(-1, self.num_heads, self.window_size * self.window_size,\n                                                       C // self.num_heads).transpose(1, 2)\n        spatial_mlp_windows = spatial_mlp_windows.reshape(-1, self.window_size * self.window_size, C)\n\n        # merge windows\n        spatial_mlp_windows = spatial_mlp_windows.reshape(-1, self.window_size, self.window_size, C)\n        shifted_x = window_reverse(spatial_mlp_windows, self.window_size, _H, _W)  # B H' W' C\n\n        # reverse shift\n        if self.shift_size > 
0:\n            P_l, P_r, P_t, P_b = self.padding\n            x = shifted_x[:, P_t:-P_b, P_l:-P_r, :].contiguous()\n        else:\n            x = shifted_x\n        x = x.view(B, H * W, C)\n\n        # FFN\n        x = shortcut + self.drop_path(x)\n        x = x + self.drop_path(self.mlp(self.norm2(x)))\n\n        return x\n\n    def extra_repr(self) -> str:\n        return f\"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, \" \\\n               f\"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}\"\n\n    def flops(self):\n        flops = 0\n        H, W = self.input_resolution\n        # norm1\n        flops += self.dim * H * W\n\n        # Window/Shifted-Window Spatial MLP\n        if self.shift_size > 0:\n            nW = (H / self.window_size + 1) * (W / self.window_size + 1)\n        else:\n            nW = H * W / self.window_size / self.window_size\n        flops += nW * self.dim * (self.window_size * self.window_size) * (self.window_size * self.window_size)\n        # mlp\n        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio\n        # norm2\n        flops += self.dim * H * W\n        return flops\n\n\nclass PatchMerging(nn.Module):\n    r\"\"\" Patch Merging Layer.\n\n    Args:\n        input_resolution (tuple[int]): Resolution of input feature.\n        dim (int): Number of input channels.\n        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm\n    \"\"\"\n\n    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):\n        super().__init__()\n        self.input_resolution = input_resolution\n        self.dim = dim\n        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)\n        self.norm = norm_layer(4 * dim)\n\n    def forward(self, x):\n        \"\"\"\n        x: B, H*W, C\n        \"\"\"\n        H, W = self.input_resolution\n        B, L, C = x.shape\n        assert L == H * W, \"input feature has wrong size\"\n        assert H % 2 == 0 and W % 2 == 0, f\"x size ({H}*{W}) are not even.\"\n\n        x = x.view(B, H, W, C)\n\n        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C\n        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C\n        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C\n        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C\n        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C\n        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C\n\n        x = self.norm(x)\n        x = self.reduction(x)\n\n        return x\n\n    def extra_repr(self) -> str:\n        return f\"input_resolution={self.input_resolution}, dim={self.dim}\"\n\n    def flops(self):\n        H, W = self.input_resolution\n        flops = H * W * self.dim\n        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim\n        return flops\n\n\nclass BasicLayer(nn.Module):\n    \"\"\" A basic Swin MLP layer for one stage.\n\n    Args:\n        dim (int): Number of input channels.\n        input_resolution (tuple[int]): Input resolution.\n        depth (int): Number of blocks.\n        num_heads (int): Number of attention heads.\n        window_size (int): Local window size.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        drop (float, optional): Dropout rate. Default: 0.0\n        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0\n        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm\n        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. 
Default: None\n        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.\n    \"\"\"\n\n    def __init__(self, dim, input_resolution, depth, num_heads, window_size,\n                 mlp_ratio=4., drop=0., drop_path=0.,\n                 norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):\n\n        super().__init__()\n        self.dim = dim\n        self.input_resolution = input_resolution\n        self.depth = depth\n        self.use_checkpoint = use_checkpoint\n\n        # build blocks\n        self.blocks = nn.ModuleList([\n            SwinMLPBlock(dim=dim, input_resolution=input_resolution,\n                         num_heads=num_heads, window_size=window_size,\n                         shift_size=0 if (i % 2 == 0) else window_size // 2,\n                         mlp_ratio=mlp_ratio,\n                         drop=drop,\n                         drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,\n                         norm_layer=norm_layer)\n            for i in range(depth)])\n\n        # patch merging layer\n        if downsample is not None:\n            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)\n        else:\n            self.downsample = None\n\n    def forward(self, x):\n        for blk in self.blocks:\n            if self.use_checkpoint:\n                x = checkpoint.checkpoint(blk, x)\n            else:\n                x = blk(x)\n        if self.downsample is not None:\n            x = self.downsample(x)\n        return x\n\n    def extra_repr(self) -> str:\n        return f\"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}\"\n\n    def flops(self):\n        flops = 0\n        for blk in self.blocks:\n            flops += blk.flops()\n        if self.downsample is not None:\n            flops += self.downsample.flops()\n        return flops\n\n\nclass PatchEmbed(nn.Module):\n    r\"\"\" Image to Patch Embedding\n\n    Args:\n        img_size (int): Image size.  Default: 224.\n        patch_size (int): Patch token size. Default: 4.\n        in_chans (int): Number of input image channels. Default: 3.\n        embed_dim (int): Number of linear projection output channels. Default: 96.\n        norm_layer (nn.Module, optional): Normalization layer. 
Default: None\n    \"\"\"\n\n    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.patches_resolution = patches_resolution\n        self.num_patches = patches_resolution[0] * patches_resolution[1]\n\n        self.in_chans = in_chans\n        self.embed_dim = embed_dim\n\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n        if norm_layer is not None:\n            self.norm = norm_layer(embed_dim)\n        else:\n            self.norm = None\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n        # FIXME look at relaxing size constraints\n        assert H == self.img_size[0] and W == self.img_size[1], \\\n            f\"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).\"\n        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C\n        if self.norm is not None:\n            x = self.norm(x)\n        return x\n\n    def flops(self):\n        Ho, Wo = self.patches_resolution\n        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])\n        if self.norm is not None:\n            flops += Ho * Wo * self.embed_dim\n        return flops\n\n\nclass SwinMLP(nn.Module):\n    r\"\"\" Swin MLP\n\n    Args:\n        img_size (int | tuple(int)): Input image size. Default 224\n        patch_size (int | tuple(int)): Patch size. Default: 4\n        in_chans (int): Number of input image channels. Default: 3\n        num_classes (int): Number of classes for classification head. Default: 1000\n        embed_dim (int): Patch embedding dimension. Default: 96\n        depths (tuple(int)): Depth of each Swin MLP layer.\n        num_heads (tuple(int)): Number of attention heads in different layers.\n        window_size (int): Window size. Default: 7\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4\n        drop_rate (float): Dropout rate. Default: 0\n        drop_path_rate (float): Stochastic depth rate. Default: 0.1\n        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.\n        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False\n        patch_norm (bool): If True, add normalization after patch embedding. Default: True\n        use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False\n    \"\"\"\n\n    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,\n                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],\n                 window_size=7, mlp_ratio=4., drop_rate=0., drop_path_rate=0.1,\n                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,\n                 use_checkpoint=False, **kwargs):\n        super().__init__()\n\n        self.num_classes = num_classes\n        self.num_layers = len(depths)\n        self.embed_dim = embed_dim\n        self.ape = ape\n        self.patch_norm = patch_norm\n        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))\n        self.mlp_ratio = mlp_ratio\n\n        # split image into non-overlapping patches\n        self.patch_embed = PatchEmbed(\n            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,\n            norm_layer=norm_layer if self.patch_norm else None)\n        num_patches = self.patch_embed.num_patches\n        patches_resolution = self.patch_embed.patches_resolution\n        self.patches_resolution = patches_resolution\n\n        # absolute position embedding\n        if self.ape:\n            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))\n            trunc_normal_(self.absolute_pos_embed, std=.02)\n\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n        # stochastic depth\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule\n\n        # build layers\n        self.layers = nn.ModuleList()\n        for i_layer in range(self.num_layers):\n            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),\n                               input_resolution=(patches_resolution[0] // (2 ** i_layer),\n                                                 patches_resolution[1] // (2 ** i_layer)),\n                               depth=depths[i_layer],\n                               num_heads=num_heads[i_layer],\n                               window_size=window_size,\n                               mlp_ratio=self.mlp_ratio,\n                               drop=drop_rate,\n                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],\n                               norm_layer=norm_layer,\n                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,\n                               use_checkpoint=use_checkpoint)\n            self.layers.append(layer)\n\n        self.norm = norm_layer(self.num_features)\n        self.avgpool = nn.AdaptiveAvgPool1d(1)\n        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()\n\n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, (nn.Linear, nn.Conv1d)):\n            trunc_normal_(m.weight, std=.02)\n            if m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {'absolute_pos_embed'}\n\n    @torch.jit.ignore\n    def no_weight_decay_keywords(self):\n        return {'relative_position_bias_table'}\n\n    def forward_features(self, x):\n        x = self.patch_embed(x)\n        if self.ape:\n            x = x + self.absolute_pos_embed\n        x = self.pos_drop(x)\n\n        for layer in self.layers:\n      
      x = layer(x)\n\n        x = self.norm(x)  # B L C\n        x = self.avgpool(x.transpose(1, 2))  # B C 1\n        x = torch.flatten(x, 1)\n        return x\n\n    def forward(self, x):\n        x = self.forward_features(x)\n        x = self.head(x)\n        return x\n\n    def flops(self):\n        flops = 0\n        flops += self.patch_embed.flops()\n        for i, layer in enumerate(self.layers):\n            flops += layer.flops()\n        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)\n        flops += self.num_features * self.num_classes\n        return flops\n"
  },
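  {
    "path": "examples/sketch_window_partition.py",
    "content": "# Hypothetical sketch (editor addition, not part of the original codebase):\n# checks that window_partition and window_reverse from swin_mlp.py are exact\n# inverses. Assumes it is run from the repo root with torch/timm installed.\nimport torch\n\nfrom models.models.backbones.swin_mlp import window_partition, window_reverse\n\nB, H, W, C, window_size = 2, 8, 8, 16, 4\nx = torch.randn(B, H, W, C)\n\n# (B, H, W, C) -> (num_windows * B, window_size, window_size, C)\nwindows = window_partition(x, window_size)\nprint(windows.shape)  # torch.Size([8, 4, 4, 16])\n\n# folding the windows back recovers the original tensor exactly\nx_back = window_reverse(windows, window_size, H, W)\nprint(torch.equal(x, x_back))  # True\n"
  },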
  {
    "path": "models/models/backbones/swin_transformer.py",
    "content": "# --------------------------------------------------------\n# Swin Transformer\n# Copyright (c) 2021 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Ze Liu\n# --------------------------------------------------------\n\nimport torch\nimport torch.nn as nn\nimport torch.utils.checkpoint as checkpoint\nfrom timm.models.layers import DropPath, to_2tuple, trunc_normal_\n\ntry:\n    import os, sys\n\n    kernel_path = os.path.abspath(os.path.join('..'))\n    sys.path.append(kernel_path)\n    from kernels.window_process.window_process import WindowProcess, WindowProcessReverse\n\nexcept:\n    WindowProcess = None\n    WindowProcessReverse = None\n    print(\"[Warning] Fused window process have not been installed. Please refer to get_started.md for installation.\")\n\n\nclass Mlp(nn.Module):\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\ndef window_partition(x, window_size):\n    \"\"\"\n    Args:\n        x: (B, H, W, C)\n        window_size (int): window size\n\n    Returns:\n        windows: (num_windows*B, window_size, window_size, C)\n    \"\"\"\n    B, H, W, C = x.shape\n    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)\n    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)\n    return windows\n\n\ndef window_reverse(windows, window_size, H, W):\n    \"\"\"\n    Args:\n        windows: (num_windows*B, window_size, window_size, C)\n        window_size (int): Window size\n        H (int): Height of image\n        W (int): Width of image\n\n    Returns:\n        x: (B, H, W, C)\n    \"\"\"\n    B = int(windows.shape[0] / (H * W / window_size / window_size))\n    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)\n    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)\n    return x\n\n\nclass WindowAttention(nn.Module):\n    r\"\"\" Window based multi-head self attention (W-MSA) module with relative position bias.\n    It supports both of shifted and non-shifted window.\n\n    Args:\n        dim (int): Number of input channels.\n        window_size (tuple[int]): The height and width of the window.\n        num_heads (int): Number of attention heads.\n        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set\n        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0\n        proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0\n    \"\"\"\n\n    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):\n\n        super().__init__()\n        self.dim = dim\n        self.window_size = window_size  # Wh, Ww\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        self.scale = qk_scale or head_dim ** -0.5\n\n        # define a parameter table of relative position bias\n        self.relative_position_bias_table = nn.Parameter(\n            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH\n\n        # get pair-wise relative position index for each token inside the window\n        coords_h = torch.arange(self.window_size[0])\n        coords_w = torch.arange(self.window_size[1])\n        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww\n        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww\n        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww\n        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2\n        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0\n        relative_coords[:, :, 1] += self.window_size[1] - 1\n        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1\n        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\n        self.register_buffer(\"relative_position_index\", relative_position_index)\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n        trunc_normal_(self.relative_position_bias_table, std=.02)\n        self.softmax = nn.Softmax(dim=-1)\n\n    def forward(self, x, mask=None):\n        \"\"\"\n        Args:\n            x: input features with shape of (num_windows*B, N, C)\n            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None\n        \"\"\"\n        B_, N, C = x.shape\n        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)\n\n        q = q * self.scale\n        attn = (q @ k.transpose(-2, -1))\n\n        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(\n            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH\n        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww\n        attn = attn + relative_position_bias.unsqueeze(0)\n\n        if mask is not None:\n            nW = mask.shape[0]\n            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)\n            attn = attn.view(-1, self.num_heads, N, N)\n            attn = self.softmax(attn)\n        else:\n            attn = self.softmax(attn)\n\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n    def extra_repr(self) -> str:\n        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'\n\n    def flops(self, N):\n        # calculate flops for 1 window with token length of N\n        flops = 0\n        # qkv = self.qkv(x)\n        flops += N * self.dim * 3 * 
self.dim\n        # attn = (q @ k.transpose(-2, -1))\n        flops += self.num_heads * N * (self.dim // self.num_heads) * N\n        #  x = (attn @ v)\n        flops += self.num_heads * N * N * (self.dim // self.num_heads)\n        # x = self.proj(x)\n        flops += N * self.dim * self.dim\n        return flops\n\n\nclass SwinTransformerBlock(nn.Module):\n    r\"\"\" Swin Transformer Block.\n\n    Args:\n        dim (int): Number of input channels.\n        input_resolution (tuple[int]): Input resolution.\n        num_heads (int): Number of attention heads.\n        window_size (int): Window size.\n        shift_size (int): Shift size for SW-MSA.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n        drop (float, optional): Dropout rate. Default: 0.0\n        attn_drop (float, optional): Attention dropout rate. Default: 0.0\n        drop_path (float, optional): Stochastic depth rate. Default: 0.0\n        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU\n        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm\n        fused_window_process (bool, optional): If True, use one kernel to fuse window shift & window partition for acceleration, similar for the reversed part. Default: False\n    \"\"\"\n\n    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,\n                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,\n                 act_layer=nn.GELU, norm_layer=nn.LayerNorm,\n                 fused_window_process=False):\n        super().__init__()\n        self.dim = dim\n        self.input_resolution = input_resolution\n        self.num_heads = num_heads\n        self.window_size = window_size\n        self.shift_size = shift_size\n        self.mlp_ratio = mlp_ratio\n        if min(self.input_resolution) <= self.window_size:\n            # if window size is larger than input resolution, we don't partition windows\n            self.shift_size = 0\n            self.window_size = min(self.input_resolution)\n        assert 0 <= self.shift_size < self.window_size, \"shift_size must be in 0-window_size\"\n\n        self.norm1 = norm_layer(dim)\n        self.attn = WindowAttention(\n            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,\n            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)\n\n        self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n\n        if self.shift_size > 0:\n            # calculate attention mask for SW-MSA\n            H, W = self.input_resolution\n            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1\n            h_slices = (slice(0, -self.window_size),\n                        slice(-self.window_size, -self.shift_size),\n                        slice(-self.shift_size, None))\n            w_slices = (slice(0, -self.window_size),\n                        slice(-self.window_size, -self.shift_size),\n                        slice(-self.shift_size, None))\n            cnt = 0\n            for h in h_slices:\n                for w in w_slices:\n                    img_mask[:, h, w, :] = cnt\n                    cnt += 1\n\n            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1\n            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)\n            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)\n            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))\n        else:\n            attn_mask = None\n\n        self.register_buffer(\"attn_mask\", attn_mask)\n        self.fused_window_process = fused_window_process\n\n    def forward(self, x):\n        H, W = self.input_resolution\n        B, L, C = x.shape\n        assert L == H * W, \"input feature has wrong size\"\n\n        shortcut = x\n        x = self.norm1(x)\n        x = x.view(B, H, W, C)\n\n        # cyclic shift\n        if self.shift_size > 0:\n            if not self.fused_window_process:\n                shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))\n                # partition windows\n                x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C\n            else:\n                x_windows = WindowProcess.apply(x, B, H, W, C, -self.shift_size, self.window_size)\n        else:\n            shifted_x = x\n            # partition windows\n            x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C\n\n        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C\n\n        # W-MSA/SW-MSA\n        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C\n\n        # merge windows\n        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)\n\n        # reverse cyclic shift\n        if self.shift_size > 0:\n            if not self.fused_window_process:\n                shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C\n                x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))\n            else:\n                x = WindowProcessReverse.apply(attn_windows, B, H, W, C, self.shift_size, self.window_size)\n        else:\n            shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C\n            x = shifted_x\n        x = x.view(B, H * W, C)\n        x = shortcut + self.drop_path(x)\n\n        # FFN\n        x = x + self.drop_path(self.mlp(self.norm2(x)))\n\n        return x\n\n    def extra_repr(self) -> str:\n        return f\"dim={self.dim}, 
input_resolution={self.input_resolution}, num_heads={self.num_heads}, \" \\\n               f\"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}\"\n\n    def flops(self):\n        flops = 0\n        H, W = self.input_resolution\n        # norm1\n        flops += self.dim * H * W\n        # W-MSA/SW-MSA\n        nW = H * W / self.window_size / self.window_size\n        flops += nW * self.attn.flops(self.window_size * self.window_size)\n        # mlp\n        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio\n        # norm2\n        flops += self.dim * H * W\n        return flops\n\n\nclass PatchMerging(nn.Module):\n    r\"\"\" Patch Merging Layer.\n\n    Args:\n        input_resolution (tuple[int]): Resolution of input feature.\n        dim (int): Number of input channels.\n        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm\n    \"\"\"\n\n    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):\n        super().__init__()\n        self.input_resolution = input_resolution\n        self.dim = dim\n        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)\n        self.norm = norm_layer(4 * dim)\n\n    def forward(self, x):\n        \"\"\"\n        x: B, H*W, C\n        \"\"\"\n        H, W = self.input_resolution\n        B, L, C = x.shape\n        assert L == H * W, \"input feature has wrong size\"\n        assert H % 2 == 0 and W % 2 == 0, f\"x size ({H}*{W}) is not even.\"\n\n        x = x.view(B, H, W, C)\n\n        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C\n        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C\n        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C\n        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C\n        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C\n        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C\n\n        x = self.norm(x)\n        x = self.reduction(x)\n\n        return x\n\n    def extra_repr(self) -> str:\n        return f\"input_resolution={self.input_resolution}, dim={self.dim}\"\n\n    def flops(self):\n        H, W = self.input_resolution\n        flops = H * W * self.dim\n        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim\n        return flops\n\n\nclass BasicLayer(nn.Module):\n    \"\"\" A basic Swin Transformer layer for one stage.\n\n    Args:\n        dim (int): Number of input channels.\n        input_resolution (tuple[int]): Input resolution.\n        depth (int): Number of blocks.\n        num_heads (int): Number of attention heads.\n        window_size (int): Local window size.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n        drop (float, optional): Dropout rate. Default: 0.0\n        attn_drop (float, optional): Attention dropout rate. Default: 0.0\n        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0\n        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm\n        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None\n        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.\n        fused_window_process (bool, optional): If True, use one kernel to fuse window shift & window partition for acceleration, similar for the reversed part. 
Default: False\n    \"\"\"\n\n    def __init__(self, dim, input_resolution, depth, num_heads, window_size,\n                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,\n                 fused_window_process=False):\n\n        super().__init__()\n        self.dim = dim\n        self.input_resolution = input_resolution\n        self.depth = depth\n        self.use_checkpoint = use_checkpoint\n\n        # build blocks\n        self.blocks = nn.ModuleList([\n            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,\n                                 num_heads=num_heads, window_size=window_size,\n                                 shift_size=0 if (i % 2 == 0) else window_size // 2,\n                                 mlp_ratio=mlp_ratio,\n                                 qkv_bias=qkv_bias, qk_scale=qk_scale,\n                                 drop=drop, attn_drop=attn_drop,\n                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,\n                                 norm_layer=norm_layer,\n                                 fused_window_process=fused_window_process)\n            for i in range(depth)])\n\n        # patch merging layer\n        if downsample is not None:\n            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)\n        else:\n            self.downsample = None\n\n    def forward(self, x):\n        for blk in self.blocks:\n            if self.use_checkpoint:\n                x = checkpoint.checkpoint(blk, x)\n            else:\n                x = blk(x)\n        if self.downsample is not None:\n            x = self.downsample(x)\n        return x\n\n    def extra_repr(self) -> str:\n        return f\"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}\"\n\n    def flops(self):\n        flops = 0\n        for blk in self.blocks:\n            flops += blk.flops()\n        if self.downsample is not None:\n            flops += self.downsample.flops()\n        return flops\n\n\nclass PatchEmbed(nn.Module):\n    r\"\"\" Image to Patch Embedding\n\n    Args:\n        img_size (int): Image size.  Default: 224.\n        patch_size (int): Patch token size. Default: 4.\n        in_chans (int): Number of input image channels. Default: 3.\n        embed_dim (int): Number of linear projection output channels. Default: 96.\n        norm_layer (nn.Module, optional): Normalization layer. 
Default: None\n    \"\"\"\n\n    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.patches_resolution = patches_resolution\n        self.num_patches = patches_resolution[0] * patches_resolution[1]\n\n        self.in_chans = in_chans\n        self.embed_dim = embed_dim\n\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n        if norm_layer is not None:\n            self.norm = norm_layer(embed_dim)\n        else:\n            self.norm = None\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n        # FIXME look at relaxing size constraints\n        assert H == self.img_size[0] and W == self.img_size[1], \\\n            f\"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).\"\n        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C\n        if self.norm is not None:\n            x = self.norm(x)\n        return x\n\n    def flops(self):\n        Ho, Wo = self.patches_resolution\n        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])\n        if self.norm is not None:\n            flops += Ho * Wo * self.embed_dim\n        return flops\n\n\nclass SwinTransformer(nn.Module):\n    r\"\"\" Swin Transformer\n        A PyTorch impl of: `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -\n          https://arxiv.org/pdf/2103.14030\n\n    Args:\n        img_size (int | tuple(int)): Input image size. Default 224\n        patch_size (int | tuple(int)): Patch size. Default: 4\n        in_chans (int): Number of input image channels. Default: 3\n        num_classes (int): Number of classes for classification head. Default: 1000\n        embed_dim (int): Patch embedding dimension. Default: 96\n        depths (tuple(int)): Depth of each Swin Transformer layer.\n        num_heads (tuple(int)): Number of attention heads in different layers.\n        window_size (int): Window size. Default: 7\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4\n        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None\n        drop_rate (float): Dropout rate. Default: 0\n        attn_drop_rate (float): Attention dropout rate. Default: 0\n        drop_path_rate (float): Stochastic depth rate. Default: 0.1\n        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.\n        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False\n        patch_norm (bool): If True, add normalization after patch embedding. Default: True\n        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False\n        fused_window_process (bool, optional): If True, use one kernel to fuse window shift & window partition for acceleration, similar for the reversed part. 
Default: False\n    \"\"\"\n\n    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,\n                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],\n                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,\n                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,\n                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,\n                 use_checkpoint=False, fused_window_process=False, **kwargs):\n        super().__init__()\n\n        self.num_classes = num_classes\n        self.num_layers = len(depths)\n        self.embed_dim = embed_dim\n        self.ape = ape\n        self.patch_norm = patch_norm\n        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))\n        self.mlp_ratio = mlp_ratio\n\n        # split image into non-overlapping patches\n        self.patch_embed = PatchEmbed(\n            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,\n            norm_layer=norm_layer if self.patch_norm else None)\n        num_patches = self.patch_embed.num_patches\n        patches_resolution = self.patch_embed.patches_resolution\n        self.patches_resolution = patches_resolution\n\n        # absolute position embedding\n        if self.ape:\n            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))\n            trunc_normal_(self.absolute_pos_embed, std=.02)\n\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n        # stochastic depth\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule\n\n        # build layers\n        self.layers = nn.ModuleList()\n        for i_layer in range(self.num_layers):\n            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),\n                               input_resolution=(patches_resolution[0] // (2 ** i_layer),\n                                                 patches_resolution[1] // (2 ** i_layer)),\n                               depth=depths[i_layer],\n                               num_heads=num_heads[i_layer],\n                               window_size=window_size,\n                               mlp_ratio=self.mlp_ratio,\n                               qkv_bias=qkv_bias, qk_scale=qk_scale,\n                               drop=drop_rate, attn_drop=attn_drop_rate,\n                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],\n                               norm_layer=norm_layer,\n                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,\n                               use_checkpoint=use_checkpoint,\n                               fused_window_process=fused_window_process)\n            self.layers.append(layer)\n\n        self.norm = norm_layer(self.num_features)\n        self.avgpool = nn.AdaptiveAvgPool1d(1)\n        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()\n\n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {'absolute_pos_embed'}\n\n    @torch.jit.ignore\n    def 
no_weight_decay_keywords(self):\n        return {'relative_position_bias_table'}\n\n    def forward_features(self, x):\n        x = self.patch_embed(x)\n        if self.ape:\n            x = x + self.absolute_pos_embed\n        x = self.pos_drop(x)\n\n        for layer in self.layers:\n            x = layer(x)\n\n        x = self.norm(x)  # B L C\n        x = self.avgpool(x.transpose(1, 2))  # B C 1\n        x = torch.flatten(x, 1)\n        return x\n\n    def forward(self, x):\n        x = self.forward_features(x)\n        x = self.head(x)\n        return x\n\n    def flops(self):\n        flops = 0\n        flops += self.patch_embed.flops()\n        for i, layer in enumerate(self.layers):\n            flops += layer.flops()\n        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)\n        flops += self.num_features * self.num_classes\n        return flops\n"
  },
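  {
    "path": "models/models/backbones/examples/swin_transformer_demo.py",
    "content": "# --------------------------------------------------------\n# Illustrative example -- not part of the original Swin release.\n# A minimal sketch that builds the Swin-T configuration of the\n# SwinTransformer backbone defined in swin_transformer.py and runs a dummy\n# forward pass plus the analytic flops() estimate. The file name and\n# location are hypothetical; it assumes the repository root is on\n# PYTHONPATH (adjust the import to your package layout).\n# --------------------------------------------------------\n\nimport torch\n\nfrom models.models.backbones.swin_transformer import SwinTransformer\n\nif __name__ == '__main__':\n    # Swin-T: embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24]\n    model = SwinTransformer(img_size=224, patch_size=4, in_chans=3,\n                            num_classes=1000, embed_dim=96,\n                            depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],\n                            window_size=7)\n    model.eval()\n\n    # PatchEmbed asserts the exact img_size, so feed a 224x224 input.\n    x = torch.randn(1, 3, 224, 224)\n    with torch.no_grad():\n        logits = model(x)\n    print('logits shape:', tuple(logits.shape))  # (1, 1000)\n    print('params (M): %.1f' % (sum(p.numel() for p in model.parameters()) / 1e6))\n    print('GFLOPs: %.2f' % (model.flops() / 1e9))\n"
  },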
  {
    "path": "models/models/backbones/swin_transformer_moe.py",
    "content": "# --------------------------------------------------------\n# Swin Transformer MoE\n# Copyright (c) 2022 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Ze Liu\n# --------------------------------------------------------\n\nimport numpy as np\nimport torch\nimport torch.distributed as dist\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.utils.checkpoint as checkpoint\nfrom timm.models.layers import DropPath, to_2tuple, trunc_normal_\n\ntry:\n    from tutel import moe as tutel_moe\nexcept:\n    tutel_moe = None\n    print(\"Tutel has not been installed. To use Swin-MoE, please install Tutel; otherwise, just ignore this.\")\n\n\nclass Mlp(nn.Module):\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.,\n                 mlp_fc2_bias=True):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features, bias=mlp_fc2_bias)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\nclass MoEMlp(nn.Module):\n    def __init__(self, in_features, hidden_features, num_local_experts, top_value, capacity_factor=1.25,\n                 cosine_router=False, normalize_gate=False, use_bpr=True, is_gshard_loss=True,\n                 gate_noise=1.0, cosine_router_dim=256, cosine_router_init_t=0.5, moe_drop=0.0, init_std=0.02,\n                 mlp_fc2_bias=True):\n        super().__init__()\n\n        self.in_features = in_features\n        self.hidden_features = hidden_features\n        self.num_local_experts = num_local_experts\n        self.top_value = top_value\n        self.capacity_factor = capacity_factor\n        self.cosine_router = cosine_router\n        self.normalize_gate = normalize_gate\n        self.use_bpr = use_bpr\n        self.init_std = init_std\n        self.mlp_fc2_bias = mlp_fc2_bias\n\n        self.dist_rank = dist.get_rank()\n\n        self._dropout = nn.Dropout(p=moe_drop)\n\n        _gate_type = {'type': 'cosine_top' if cosine_router else 'top',\n                      'k': top_value, 'capacity_factor': capacity_factor,\n                      'gate_noise': gate_noise, 'fp32_gate': True}\n        if cosine_router:\n            _gate_type['proj_dim'] = cosine_router_dim\n            _gate_type['init_t'] = cosine_router_init_t\n        self._moe_layer = tutel_moe.moe_layer(\n            gate_type=_gate_type,\n            model_dim=in_features,\n            experts={'type': 'ffn', 'count_per_node': num_local_experts, 'hidden_size_per_expert': hidden_features,\n                     'activation_fn': lambda x: self._dropout(F.gelu(x))},\n            scan_expert_func=lambda name, param: setattr(param, 'skip_allreduce', True),\n            seeds=(1, self.dist_rank + 1, self.dist_rank + 1),\n            batch_prioritized_routing=use_bpr,\n            normalize_gate=normalize_gate,\n            is_gshard_loss=is_gshard_loss,\n\n        )\n        if not self.mlp_fc2_bias:\n            self._moe_layer.experts.batched_fc2_bias.requires_grad = False\n\n    def forward(self, x):\n        x = self._moe_layer(x)\n        return x, x.l_aux\n\n    def extra_repr(self) -> str:\n        return 
f'[Statistics-{self.dist_rank}] param count for MoE, ' \\\n               f'in_features = {self.in_features}, hidden_features = {self.hidden_features}, ' \\\n               f'num_local_experts = {self.num_local_experts}, top_value = {self.top_value}, ' \\\n               f'cosine_router={self.cosine_router} normalize_gate={self.normalize_gate}, use_bpr = {self.use_bpr}'\n\n    def _init_weights(self):\n        if hasattr(self._moe_layer, \"experts\"):\n            trunc_normal_(self._moe_layer.experts.batched_fc1_w, std=self.init_std)\n            trunc_normal_(self._moe_layer.experts.batched_fc2_w, std=self.init_std)\n            nn.init.constant_(self._moe_layer.experts.batched_fc1_bias, 0)\n            nn.init.constant_(self._moe_layer.experts.batched_fc2_bias, 0)\n\n\ndef window_partition(x, window_size):\n    \"\"\"\n    Args:\n        x: (B, H, W, C)\n        window_size (int): window size\n\n    Returns:\n        windows: (num_windows*B, window_size, window_size, C)\n    \"\"\"\n    B, H, W, C = x.shape\n    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)\n    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)\n    return windows\n\n\ndef window_reverse(windows, window_size, H, W):\n    \"\"\"\n    Args:\n        windows: (num_windows*B, window_size, window_size, C)\n        window_size (int): Window size\n        H (int): Height of image\n        W (int): Width of image\n\n    Returns:\n        x: (B, H, W, C)\n    \"\"\"\n    B = int(windows.shape[0] / (H * W / window_size / window_size))\n    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)\n    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)\n    return x\n\n\nclass WindowAttention(nn.Module):\n    r\"\"\" Window based multi-head self attention (W-MSA) module with relative position bias.\n    It supports both of shifted and non-shifted window.\n\n    Args:\n        dim (int): Number of input channels.\n        window_size (tuple[int]): The height and width of the window.\n        num_heads (int): Number of attention heads.\n        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set\n        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0\n        proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0\n        pretrained_window_size (tuple[int]): The height and width of the window in pre-training.\n    \"\"\"\n\n    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.,\n                 pretrained_window_size=[0, 0]):\n\n        super().__init__()\n        self.dim = dim\n        self.window_size = window_size  # Wh, Ww\n        self.pretrained_window_size = pretrained_window_size\n        self.num_heads = num_heads\n\n        head_dim = dim // num_heads\n        self.scale = qk_scale or head_dim ** -0.5\n\n        # mlp to generate continuous relative position bias\n        self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),\n                                     nn.ReLU(inplace=True),\n                                     nn.Linear(512, num_heads, bias=False))\n\n        # get relative_coords_table\n        relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)\n        relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)\n        relative_coords_table = torch.stack(\n            torch.meshgrid([relative_coords_h,\n                            relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0)  # 1, 2*Wh-1, 2*Ww-1, 2\n        if pretrained_window_size[0] > 0:\n            relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)\n            relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)\n        else:\n            relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)\n            relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)\n        relative_coords_table *= 8  # normalize to -8, 8\n        relative_coords_table = torch.sign(relative_coords_table) * torch.log2(\n            torch.abs(relative_coords_table) + 1.0) / np.log2(8)\n\n        self.register_buffer(\"relative_coords_table\", relative_coords_table)\n\n        # get pair-wise relative position index for each token inside the window\n        coords_h = torch.arange(self.window_size[0])\n        coords_w = torch.arange(self.window_size[1])\n        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww\n        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww\n        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww\n        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2\n        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0\n        relative_coords[:, :, 1] += self.window_size[1] - 1\n        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1\n        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\n        self.register_buffer(\"relative_position_index\", relative_position_index)\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n        self.softmax = nn.Softmax(dim=-1)\n\n    def forward(self, x, mask=None):\n        \"\"\"\n        Args:\n            x: input features with shape of (num_windows*B, N, C)\n            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None\n        \"\"\"\n        B_, N, C = x.shape\n        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n        q, k, v = qkv[0], qkv[1], qkv[2]  # make 
torchscript happy (cannot use tensor as tuple)\n\n        q = q * self.scale\n        attn = (q @ k.transpose(-2, -1))\n\n        relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)\n        relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(\n            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH\n        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww\n        attn = attn + relative_position_bias.unsqueeze(0)\n\n        if mask is not None:\n            nW = mask.shape[0]\n            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)\n            attn = attn.view(-1, self.num_heads, N, N)\n            attn = self.softmax(attn)\n        else:\n            attn = self.softmax(attn)\n\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n    def extra_repr(self) -> str:\n        return f'dim={self.dim}, window_size={self.window_size}, ' \\\n               f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}'\n\n    def flops(self, N):\n        # calculate flops for 1 window with token length of N\n        flops = 0\n        # qkv = self.qkv(x)\n        flops += N * self.dim * 3 * self.dim\n        # attn = (q @ k.transpose(-2, -1))\n        flops += self.num_heads * N * (self.dim // self.num_heads) * N\n        #  x = (attn @ v)\n        flops += self.num_heads * N * N * (self.dim // self.num_heads)\n        # x = self.proj(x)\n        flops += N * self.dim * self.dim\n        return flops\n\n\nclass SwinTransformerBlock(nn.Module):\n    r\"\"\" Swin Transformer Block.\n\n    Args:\n        dim (int): Number of input channels.\n        input_resolution (tuple[int]): Input resolution.\n        num_heads (int): Number of attention heads.\n        window_size (int): Window size.\n        shift_size (int): Shift size for SW-MSA.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n        drop (float, optional): Dropout rate. Default: 0.0\n        attn_drop (float, optional): Attention dropout rate. Default: 0.0\n        drop_path (float, optional): Stochastic depth rate. Default: 0.0\n        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU\n        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm\n        mlp_fc2_bias (bool): Whether to add bias in fc2 of Mlp. Default: True\n        init_std: Initialization std. Default: 0.02\n        pretrained_window_size (int): Window size in pre-training.\n        is_moe (bool): If True, this block is a MoE block.\n        num_local_experts (int): number of local experts in each device (GPU). Default: 1\n        top_value (int): the value of k in top-k gating. Default: 1\n        capacity_factor (float): the capacity factor in MoE. Default: 1.25\n        cosine_router (bool): Whether to use cosine router. Default: False\n        normalize_gate (bool): Whether to normalize the gating score in top-k gating. Default: False\n        use_bpr (bool): Whether to use batch-prioritized-routing. 
Default: True\n        is_gshard_loss (bool): If True, use Gshard balance loss.\n                               If False, use the load loss and importance loss in \"arXiv:1701.06538\". Default: True\n        gate_noise (float): the noise ratio in top-k gating. Default: 1.0\n        cosine_router_dim (int): Projection dimension in cosine router.\n        cosine_router_init_t (float): Initialization temperature in cosine router.\n        moe_drop (float): Dropout rate in MoE. Default: 0.0\n    \"\"\"\n\n    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,\n                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,\n                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, mlp_fc2_bias=True, init_std=0.02, pretrained_window_size=0,\n                 is_moe=False, num_local_experts=1, top_value=1, capacity_factor=1.25, cosine_router=False,\n                 normalize_gate=False, use_bpr=True, is_gshard_loss=True, gate_noise=1.0,\n                 cosine_router_dim=256, cosine_router_init_t=0.5, moe_drop=0.0):\n        super().__init__()\n        self.dim = dim\n        self.input_resolution = input_resolution\n        self.num_heads = num_heads\n        self.window_size = window_size\n        self.shift_size = shift_size\n        self.mlp_ratio = mlp_ratio\n        self.is_moe = is_moe\n        self.capacity_factor = capacity_factor\n        self.top_value = top_value\n\n        if min(self.input_resolution) <= self.window_size:\n            # if window size is larger than input resolution, we don't partition windows\n            self.shift_size = 0\n            self.window_size = min(self.input_resolution)\n        assert 0 <= self.shift_size < self.window_size, \"shift_size must be in 0-window_size\"\n\n        self.norm1 = norm_layer(dim)\n        self.attn = WindowAttention(\n            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,\n            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,\n            pretrained_window_size=to_2tuple(pretrained_window_size))\n\n        self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        if self.is_moe:\n            self.mlp = MoEMlp(in_features=dim,\n                              hidden_features=mlp_hidden_dim,\n                              num_local_experts=num_local_experts,\n                              top_value=top_value,\n                              capacity_factor=capacity_factor,\n                              cosine_router=cosine_router,\n                              normalize_gate=normalize_gate,\n                              use_bpr=use_bpr,\n                              is_gshard_loss=is_gshard_loss,\n                              gate_noise=gate_noise,\n                              cosine_router_dim=cosine_router_dim,\n                              cosine_router_init_t=cosine_router_init_t,\n                              moe_drop=moe_drop,\n                              mlp_fc2_bias=mlp_fc2_bias,\n                              init_std=init_std)\n        else:\n            self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop,\n                           mlp_fc2_bias=mlp_fc2_bias)\n\n        if self.shift_size > 0:\n            # calculate attention mask for SW-MSA\n            H, W = self.input_resolution\n            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1\n            h_slices = (slice(0, -self.window_size),\n                        slice(-self.window_size, -self.shift_size),\n                        slice(-self.shift_size, None))\n            w_slices = (slice(0, -self.window_size),\n                        slice(-self.window_size, -self.shift_size),\n                        slice(-self.shift_size, None))\n            cnt = 0\n            for h in h_slices:\n                for w in w_slices:\n                    img_mask[:, h, w, :] = cnt\n                    cnt += 1\n\n            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1\n            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)\n            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)\n            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))\n        else:\n            attn_mask = None\n\n        self.register_buffer(\"attn_mask\", attn_mask)\n\n    def forward(self, x):\n        H, W = self.input_resolution\n        B, L, C = x.shape\n        assert L == H * W, \"input feature has wrong size\"\n\n        shortcut = x\n        x = self.norm1(x)\n        x = x.view(B, H, W, C)\n\n        # cyclic shift\n        if self.shift_size > 0:\n            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))\n        else:\n            shifted_x = x\n\n        # partition windows\n        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C\n        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C\n\n        # W-MSA/SW-MSA\n        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C\n\n        # merge windows\n        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)\n        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C\n\n        # reverse cyclic shift\n        if self.shift_size > 0:\n            x = torch.roll(shifted_x, shifts=(self.shift_size, 
self.shift_size), dims=(1, 2))\n        else:\n            x = shifted_x\n        x = x.view(B, H * W, C)\n        x = shortcut + self.drop_path(x)\n\n        # FFN\n        shortcut = x\n        x = self.norm2(x)\n        if self.is_moe:\n            x, l_aux = self.mlp(x)\n            x = shortcut + self.drop_path(x)\n            return x, l_aux\n        else:\n            x = shortcut + self.drop_path(self.mlp(x))\n            return x\n\n    def extra_repr(self) -> str:\n        return f\"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, \" \\\n               f\"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}\"\n\n    def flops(self):\n        flops = 0\n        H, W = self.input_resolution\n        # norm1\n        flops += self.dim * H * W\n        # W-MSA/SW-MSA\n        nW = H * W / self.window_size / self.window_size\n        flops += nW * self.attn.flops(self.window_size * self.window_size)\n        # mlp\n        if self.is_moe:\n            flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio * self.capacity_factor * self.top_value\n        else:\n            flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio\n        # norm2\n        flops += self.dim * H * W\n        return flops\n\n\nclass PatchMerging(nn.Module):\n    r\"\"\" Patch Merging Layer.\n\n    Args:\n        input_resolution (tuple[int]): Resolution of input feature.\n        dim (int): Number of input channels.\n        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm\n    \"\"\"\n\n    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):\n        super().__init__()\n        self.input_resolution = input_resolution\n        self.dim = dim\n        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)\n        self.norm = norm_layer(4 * dim)\n\n    def forward(self, x):\n        \"\"\"\n        x: B, H*W, C\n        \"\"\"\n        H, W = self.input_resolution\n        B, L, C = x.shape\n        assert L == H * W, \"input feature has wrong size\"\n        assert H % 2 == 0 and W % 2 == 0, f\"x size ({H}*{W}) is not even.\"\n\n        x = x.view(B, H, W, C)\n\n        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C\n        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C\n        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C\n        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C\n        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C\n        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C\n\n        x = self.norm(x)\n        x = self.reduction(x)\n\n        return x\n\n    def extra_repr(self) -> str:\n        return f\"input_resolution={self.input_resolution}, dim={self.dim}\"\n\n    def flops(self):\n        H, W = self.input_resolution\n        flops = H * W * self.dim\n        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim\n        return flops\n\n\nclass BasicLayer(nn.Module):\n    \"\"\" A basic Swin Transformer layer for one stage.\n\n    Args:\n        dim (int): Number of input channels.\n        input_resolution (tuple[int]): Input resolution.\n        depth (int): Number of blocks.\n        num_heads (int): Number of attention heads.\n        window_size (int): Local window size.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. 
Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n        drop (float, optional): Dropout rate. Default: 0.0\n        attn_drop (float, optional): Attention dropout rate. Default: 0.0\n        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0\n        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm\n        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None\n        mlp_fc2_bias (bool): Whether to add bias in fc2 of Mlp. Default: True\n        init_std: Initialization std. Default: 0.02\n        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.\n        pretrained_window_size (int): Local window size in pre-training.\n        moe_block (tuple(int)): The index of each MoE block.\n        num_local_experts (int): number of local experts in each device (GPU). Default: 1\n        top_value (int): the value of k in top-k gating. Default: 1\n        capacity_factor (float): the capacity factor in MoE. Default: 1.25\n        cosine_router (bool): Whether to use cosine router. Default: False\n        normalize_gate (bool): Whether to normalize the gating score in top-k gating. Default: False\n        use_bpr (bool): Whether to use batch-prioritized-routing. Default: True\n        is_gshard_loss (bool): If True, use Gshard balance loss.\n                               If False, use the load loss and importance loss in \"arXiv:1701.06538\". Default: True\n        gate_noise (float): the noise ratio in top-k gating. Default: 1.0\n        cosine_router_dim (int): Projection dimension in cosine router.\n        cosine_router_init_t (float): Initialization temperature in cosine router.\n        moe_drop (float): Dropout rate in MoE. 
Default: 0.0\n    \"\"\"\n\n    def __init__(self, dim, input_resolution, depth, num_heads, window_size,\n                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None,\n                 mlp_fc2_bias=True, init_std=0.02, use_checkpoint=False, pretrained_window_size=0,\n                 moe_block=[-1], num_local_experts=1, top_value=1, capacity_factor=1.25, cosine_router=False,\n                 normalize_gate=False, use_bpr=True, is_gshard_loss=True,\n                 cosine_router_dim=256, cosine_router_init_t=0.5, gate_noise=1.0, moe_drop=0.0):\n\n        super().__init__()\n        self.dim = dim\n        self.input_resolution = input_resolution\n        self.depth = depth\n        self.use_checkpoint = use_checkpoint\n\n        # build blocks\n        self.blocks = nn.ModuleList([\n            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,\n                                 num_heads=num_heads, window_size=window_size,\n                                 shift_size=0 if (i % 2 == 0) else window_size // 2,\n                                 mlp_ratio=mlp_ratio,\n                                 qkv_bias=qkv_bias, qk_scale=qk_scale,\n                                 drop=drop, attn_drop=attn_drop,\n                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,\n                                 norm_layer=norm_layer,\n                                 mlp_fc2_bias=mlp_fc2_bias,\n                                 init_std=init_std,\n                                 pretrained_window_size=pretrained_window_size,\n\n                                 is_moe=True if i in moe_block else False,\n                                 num_local_experts=num_local_experts,\n                                 top_value=top_value,\n                                 capacity_factor=capacity_factor,\n                                 cosine_router=cosine_router,\n                                 normalize_gate=normalize_gate,\n                                 use_bpr=use_bpr,\n                                 is_gshard_loss=is_gshard_loss,\n                                 gate_noise=gate_noise,\n                                 cosine_router_dim=cosine_router_dim,\n                                 cosine_router_init_t=cosine_router_init_t,\n                                 moe_drop=moe_drop)\n            for i in range(depth)])\n\n        # patch merging layer\n        if downsample is not None:\n            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)\n        else:\n            self.downsample = None\n\n    def forward(self, x):\n        l_aux = 0.0\n        for blk in self.blocks:\n            if self.use_checkpoint:\n                out = checkpoint.checkpoint(blk, x)\n            else:\n                out = blk(x)\n            if isinstance(out, tuple):\n                x = out[0]\n                cur_l_aux = out[1]\n                l_aux = cur_l_aux + l_aux\n            else:\n                x = out\n\n        if self.downsample is not None:\n            x = self.downsample(x)\n        return x, l_aux\n\n    def extra_repr(self) -> str:\n        return f\"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}\"\n\n    def flops(self):\n        flops = 0\n        for blk in self.blocks:\n            flops += blk.flops()\n        if self.downsample is not None:\n            flops += 
self.downsample.flops()\n        return flops\n\n\nclass PatchEmbed(nn.Module):\n    r\"\"\" Image to Patch Embedding\n\n    Args:\n        img_size (int): Image size.  Default: 224.\n        patch_size (int): Patch token size. Default: 4.\n        in_chans (int): Number of input image channels. Default: 3.\n        embed_dim (int): Number of linear projection output channels. Default: 96.\n        norm_layer (nn.Module, optional): Normalization layer. Default: None\n    \"\"\"\n\n    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.patches_resolution = patches_resolution\n        self.num_patches = patches_resolution[0] * patches_resolution[1]\n\n        self.in_chans = in_chans\n        self.embed_dim = embed_dim\n\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n        if norm_layer is not None:\n            self.norm = norm_layer(embed_dim)\n        else:\n            self.norm = None\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n        # FIXME look at relaxing size constraints\n        assert H == self.img_size[0] and W == self.img_size[1], \\\n            f\"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).\"\n        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C\n        if self.norm is not None:\n            x = self.norm(x)\n        return x\n\n    def flops(self):\n        Ho, Wo = self.patches_resolution\n        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])\n        if self.norm is not None:\n            flops += Ho * Wo * self.embed_dim\n        return flops\n\n\nclass SwinTransformerMoE(nn.Module):\n    r\"\"\" Swin Transformer MoE\n        A PyTorch impl of: `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -\n          https://arxiv.org/pdf/2103.14030\n\n    Args:\n        img_size (int | tuple(int)): Input image size. Default 224\n        patch_size (int | tuple(int)): Patch size. Default: 4\n        in_chans (int): Number of input image channels. Default: 3\n        num_classes (int): Number of classes for classification head. Default: 1000\n        embed_dim (int): Patch embedding dimension. Default: 96\n        depths (tuple(int)): Depth of each Swin Transformer layer.\n        num_heads (tuple(int)): Number of attention heads in different layers.\n        window_size (int): Window size. Default: 7\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4\n        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None\n        drop_rate (float): Dropout rate. Default: 0\n        attn_drop_rate (float): Attention dropout rate. Default: 0\n        drop_path_rate (float): Stochastic depth rate. Default: 0.1\n        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.\n        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False\n        patch_norm (bool): If True, add normalization after patch embedding. 
Default: True\n        mlp_fc2_bias (bool): Whether to add bias in fc2 of Mlp. Default: True\n        init_std: Initialization std. Default: 0.02\n        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False\n        pretrained_window_sizes (tuple(int)): Pretrained window sizes of each layer.\n        moe_blocks (tuple(tuple(int))): The index of each MoE block in each layer.\n        num_local_experts (int): number of local experts in each device (GPU). Default: 1\n        top_value (int): the value of k in top-k gating. Default: 1\n        capacity_factor (float): the capacity factor in MoE. Default: 1.25\n        cosine_router (bool): Whether to use cosine router. Default: False\n        normalize_gate (bool): Whether to normalize the gating score in top-k gating. Default: False\n        use_bpr (bool): Whether to use batch-prioritized-routing. Default: True\n        is_gshard_loss (bool): If True, use Gshard balance loss.\n                               If False, use the load loss and importance loss in \"arXiv:1701.06538\". Default: True\n        gate_noise (float): the noise ratio in top-k gating. Default: 1.0\n        cosine_router_dim (int): Projection dimension in cosine router.\n        cosine_router_init_t (float): Initialization temperature in cosine router.\n        moe_drop (float): Dropout rate in MoE. Default: 0.0\n        aux_loss_weight (float): auxiliary loss weight. Default: 0.01\n    \"\"\"\n\n    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,\n                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],\n                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,\n                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,\n                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,\n                 mlp_fc2_bias=True, init_std=0.02, use_checkpoint=False, pretrained_window_sizes=[0, 0, 0, 0],\n                 moe_blocks=[[-1], [-1], [-1], [-1]], num_local_experts=1, top_value=1, capacity_factor=1.25,\n                 cosine_router=False, normalize_gate=False, use_bpr=True, is_gshard_loss=True, gate_noise=1.0,\n                 cosine_router_dim=256, cosine_router_init_t=0.5, moe_drop=0.0, aux_loss_weight=0.01, **kwargs):\n        super().__init__()\n        self._ddp_params_and_buffers_to_ignore = list()\n\n        self.num_classes = num_classes\n        self.num_layers = len(depths)\n        self.embed_dim = embed_dim\n        self.ape = ape\n        self.patch_norm = patch_norm\n        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))\n        self.mlp_ratio = mlp_ratio\n        self.init_std = init_std\n        self.aux_loss_weight = aux_loss_weight\n        self.num_local_experts = num_local_experts\n        self.global_experts = num_local_experts * dist.get_world_size() if num_local_experts > 0 \\\n            else dist.get_world_size() // (-num_local_experts)\n        self.sharded_count = (1.0 / num_local_experts) if num_local_experts > 0 else (-num_local_experts)\n\n        # split image into non-overlapping patches\n        self.patch_embed = PatchEmbed(\n            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,\n            norm_layer=norm_layer if self.patch_norm else None)\n        num_patches = self.patch_embed.num_patches\n        patches_resolution = self.patch_embed.patches_resolution\n        self.patches_resolution = patches_resolution\n\n        # absolute position embedding\n        if 
self.ape:\n            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))\n            trunc_normal_(self.absolute_pos_embed, std=self.init_std)\n\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n        # stochastic depth\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule\n\n        # build layers\n        self.layers = nn.ModuleList()\n        for i_layer in range(self.num_layers):\n            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),\n                               input_resolution=(patches_resolution[0] // (2 ** i_layer),\n                                                 patches_resolution[1] // (2 ** i_layer)),\n                               depth=depths[i_layer],\n                               num_heads=num_heads[i_layer],\n                               window_size=window_size,\n                               mlp_ratio=self.mlp_ratio,\n                               qkv_bias=qkv_bias, qk_scale=qk_scale,\n                               drop=drop_rate, attn_drop=attn_drop_rate,\n                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],\n                               norm_layer=norm_layer,\n                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,\n                               mlp_fc2_bias=mlp_fc2_bias,\n                               init_std=init_std,\n                               use_checkpoint=use_checkpoint,\n                               pretrained_window_size=pretrained_window_sizes[i_layer],\n\n                               moe_block=moe_blocks[i_layer],\n                               num_local_experts=num_local_experts,\n                               top_value=top_value,\n                               capacity_factor=capacity_factor,\n                               cosine_router=cosine_router,\n                               normalize_gate=normalize_gate,\n                               use_bpr=use_bpr,\n                               is_gshard_loss=is_gshard_loss,\n                               gate_noise=gate_noise,\n                               cosine_router_dim=cosine_router_dim,\n                               cosine_router_init_t=cosine_router_init_t,\n                               moe_drop=moe_drop)\n            self.layers.append(layer)\n\n        self.norm = norm_layer(self.num_features)\n        self.avgpool = nn.AdaptiveAvgPool1d(1)\n        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()\n\n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=self.init_std)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n        elif isinstance(m, MoEMlp):\n            m._init_weights()\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {'absolute_pos_embed'}\n\n    @torch.jit.ignore\n    def no_weight_decay_keywords(self):\n        return {\"cpb_mlp\", 'relative_position_bias_table', 'fc1_bias', 'fc2_bias',\n                'temperature', 'cosine_projector', 'sim_matrix'}\n\n    def forward_features(self, x):\n        x = self.patch_embed(x)\n        if self.ape:\n            x = x + self.absolute_pos_embed\n        x = 
self.pos_drop(x)\n        l_aux = 0.0\n        for layer in self.layers:\n            x, cur_l_aux = layer(x)\n            l_aux = cur_l_aux + l_aux\n\n        x = self.norm(x)  # B L C\n        x = self.avgpool(x.transpose(1, 2))  # B C 1\n        x = torch.flatten(x, 1)\n        return x, l_aux\n\n    def forward(self, x):\n        x, l_aux = self.forward_features(x)\n        x = self.head(x)\n        return x, l_aux * self.aux_loss_weight\n\n    def add_param_to_skip_allreduce(self, param_name):\n        self._ddp_params_and_buffers_to_ignore.append(param_name)\n\n    def flops(self):\n        flops = 0\n        flops += self.patch_embed.flops()\n        for i, layer in enumerate(self.layers):\n            flops += layer.flops()\n        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)\n        flops += self.num_features * self.num_classes\n        return flops\n"
  },
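  {
    "path": "examples/moe_aux_loss_sketch.py",
    "content": "# Hypothetical example file: a minimal sketch of how the (logits, l_aux) pair\n# returned by SwinTransformerMoE.forward above would be folded into a training\n# objective. forward() already scales the accumulated gate loss by\n# aux_loss_weight, so the caller only adds it to the task loss. The tensors\n# below are stand-ins; a real run would also need torch.distributed to be\n# initialized before SwinTransformerMoE is constructed.\nimport torch\nimport torch.nn.functional as F\n\nbatch_size, num_classes = 4, 1000\nlogits = torch.randn(batch_size, num_classes, requires_grad=True)  # stand-in model output\nl_aux = torch.tensor(0.02)  # stand-in, pre-scaled MoE load-balance loss\nlabels = torch.randint(0, num_classes, (batch_size,))\n\nloss = F.cross_entropy(logits, labels) + l_aux  # single scalar to backpropagate\nloss.backward()\nprint(float(loss))\n"
  },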
  {
    "path": "models/models/backbones/swin_transformer_v2.py",
    "content": "# --------------------------------------------------------\n# Swin Transformer V2\n# Copyright (c) 2022 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Ze Liu\n# --------------------------------------------------------\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.utils.checkpoint as checkpoint\nfrom mmpose.models.builder import BACKBONES\nfrom timm.models.layers import DropPath, to_2tuple, trunc_normal_\n\n\nclass Mlp(nn.Module):\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\ndef window_partition(x, window_size):\n    \"\"\"\n    Args:\n        x: (B, H, W, C)\n        window_size (int): window size\n\n    Returns:\n        windows: (num_windows*B, window_size, window_size, C)\n    \"\"\"\n    B, H, W, C = x.shape\n    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)\n    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)\n    return windows\n\n\ndef window_reverse(windows, window_size, H, W):\n    \"\"\"\n    Args:\n        windows: (num_windows*B, window_size, window_size, C)\n        window_size (int): Window size\n        H (int): Height of image\n        W (int): Width of image\n\n    Returns:\n        x: (B, H, W, C)\n    \"\"\"\n    B = int(windows.shape[0] / (H * W / window_size / window_size))\n    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)\n    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)\n    return x\n\n\nclass WindowAttention(nn.Module):\n    r\"\"\" Window based multi-head self attention (W-MSA) module with relative position bias.\n    It supports both of shifted and non-shifted window.\n\n    Args:\n        dim (int): Number of input channels.\n        window_size (tuple[int]): The height and width of the window.\n        num_heads (int): Number of attention heads.\n        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True\n        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0\n        proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0\n        pretrained_window_size (tuple[int]): The height and width of the window in pre-training.\n    \"\"\"\n\n    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,\n                 pretrained_window_size=[0, 0]):\n\n        super().__init__()\n        self.dim = dim\n        self.window_size = window_size  # Wh, Ww\n        self.pretrained_window_size = pretrained_window_size\n        self.num_heads = num_heads\n\n        self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)\n\n        # mlp to generate continuous relative position bias\n        self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),\n                                     nn.ReLU(inplace=True),\n                                     nn.Linear(512, num_heads, bias=False))\n\n        # get relative_coords_table\n        relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)\n        relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)\n        relative_coords_table = torch.stack(\n            torch.meshgrid([relative_coords_h,\n                            relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0)  # 1, 2*Wh-1, 2*Ww-1, 2\n        if pretrained_window_size[0] > 0:\n            relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)\n            relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)\n        else:\n            relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)\n            relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)\n        relative_coords_table *= 8  # normalize to -8, 8\n        relative_coords_table = torch.sign(relative_coords_table) * torch.log2(\n            torch.abs(relative_coords_table) + 1.0) / np.log2(8)\n\n        self.register_buffer(\"relative_coords_table\", relative_coords_table)\n\n        # get pair-wise relative position index for each token inside the window\n        coords_h = torch.arange(self.window_size[0])\n        coords_w = torch.arange(self.window_size[1])\n        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww\n        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww\n        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww\n        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2\n        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0\n        relative_coords[:, :, 1] += self.window_size[1] - 1\n        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1\n        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\n        self.register_buffer(\"relative_position_index\", relative_position_index)\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=False)\n        if qkv_bias:\n            self.q_bias = nn.Parameter(torch.zeros(dim))\n            self.v_bias = nn.Parameter(torch.zeros(dim))\n        else:\n            self.q_bias = None\n            self.v_bias = None\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n        self.softmax = nn.Softmax(dim=-1)\n\n    def forward(self, x, mask=None):\n        \"\"\"\n        Args:\n            x: input features with shape of (num_windows*B, N, C)\n            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, 
Wh*Ww) or None\n        \"\"\"\n        B_, N, C = x.shape\n        qkv_bias = None\n        if self.q_bias is not None:\n            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))\n        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)\n        qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)\n        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)\n\n        # cosine attention\n        attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))\n        logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01, device=x.device))).exp()\n        attn = attn * logit_scale\n\n        relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)\n        relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(\n            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH\n        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww\n        relative_position_bias = 16 * torch.sigmoid(relative_position_bias)\n        attn = attn + relative_position_bias.unsqueeze(0)\n\n        if mask is not None:\n            nW = mask.shape[0]\n            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)\n            attn = attn.view(-1, self.num_heads, N, N)\n            attn = self.softmax(attn)\n        else:\n            attn = self.softmax(attn)\n\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n    def extra_repr(self) -> str:\n        return f'dim={self.dim}, window_size={self.window_size}, ' \\\n               f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}'\n\n    def flops(self, N):\n        # calculate flops for 1 window with token length of N\n        flops = 0\n        # qkv = self.qkv(x)\n        flops += N * self.dim * 3 * self.dim\n        # attn = (q @ k.transpose(-2, -1))\n        flops += self.num_heads * N * (self.dim // self.num_heads) * N\n        #  x = (attn @ v)\n        flops += self.num_heads * N * N * (self.dim // self.num_heads)\n        # x = self.proj(x)\n        flops += N * self.dim * self.dim\n        return flops\n\n\nclass SwinTransformerBlock(nn.Module):\n    r\"\"\" Swin Transformer Block.\n\n    Args:\n        dim (int): Number of input channels.\n        input_resolution (tuple[int]): Input resulotion.\n        num_heads (int): Number of attention heads.\n        window_size (int): Window size.\n        shift_size (int): Shift size for SW-MSA.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n        drop (float, optional): Dropout rate. Default: 0.0\n        attn_drop (float, optional): Attention dropout rate. Default: 0.0\n        drop_path (float, optional): Stochastic depth rate. Default: 0.0\n        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU\n        norm_layer (nn.Module, optional): Normalization layer.  
Default: nn.LayerNorm\n        pretrained_window_size (int): Window size in pre-training.\n    \"\"\"\n\n    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,\n                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,\n                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0):\n        super().__init__()\n        self.dim = dim\n        self.input_resolution = input_resolution\n        self.num_heads = num_heads\n        self.window_size = window_size\n        self.shift_size = shift_size\n        self.mlp_ratio = mlp_ratio\n        if min(self.input_resolution) <= self.window_size:\n            # if window size is larger than input resolution, we don't partition windows\n            self.shift_size = 0\n            self.window_size = min(self.input_resolution)\n        assert 0 <= self.shift_size < self.window_size, \"shift_size must in 0-window_size\"\n\n        self.norm1 = norm_layer(dim)\n        self.attn = WindowAttention(\n            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,\n            qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,\n            pretrained_window_size=to_2tuple(pretrained_window_size))\n\n        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n\n        if self.shift_size > 0:\n            # calculate attention mask for SW-MSA\n            H, W = self.input_resolution\n            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1\n            h_slices = (slice(0, -self.window_size),\n                        slice(-self.window_size, -self.shift_size),\n                        slice(-self.shift_size, None))\n            w_slices = (slice(0, -self.window_size),\n                        slice(-self.window_size, -self.shift_size),\n                        slice(-self.shift_size, None))\n            cnt = 0\n            for h in h_slices:\n                for w in w_slices:\n                    img_mask[:, h, w, :] = cnt\n                    cnt += 1\n\n            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1\n            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)\n            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)\n            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))\n        else:\n            attn_mask = None\n\n        self.register_buffer(\"attn_mask\", attn_mask)\n\n    def forward(self, x):\n        H, W = self.input_resolution\n        B, L, C = x.shape\n        assert L == H * W, \"input feature has wrong size\"\n\n        shortcut = x\n        x = x.view(B, H, W, C)\n\n        # cyclic shift\n        if self.shift_size > 0:\n            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))\n        else:\n            shifted_x = x\n\n        # partition windows\n        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C\n        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C\n\n        # W-MSA/SW-MSA\n        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C\n\n        # merge windows\n   
     attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)\n        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C\n\n        # reverse cyclic shift\n        if self.shift_size > 0:\n            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))\n        else:\n            x = shifted_x\n        x = x.view(B, H * W, C)\n        x = shortcut + self.drop_path(self.norm1(x))\n\n        # FFN\n        x = x + self.drop_path(self.norm2(self.mlp(x)))\n\n        return x\n\n    def extra_repr(self) -> str:\n        return f\"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, \" \\\n               f\"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}\"\n\n    def flops(self):\n        flops = 0\n        H, W = self.input_resolution\n        # norm1\n        flops += self.dim * H * W\n        # W-MSA/SW-MSA\n        nW = H * W / self.window_size / self.window_size\n        flops += nW * self.attn.flops(self.window_size * self.window_size)\n        # mlp\n        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio\n        # norm2\n        flops += self.dim * H * W\n        return flops\n\n\nclass PatchMerging(nn.Module):\n    r\"\"\" Patch Merging Layer.\n\n    Args:\n        input_resolution (tuple[int]): Resolution of input feature.\n        dim (int): Number of input channels.\n        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm\n    \"\"\"\n\n    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):\n        super().__init__()\n        self.input_resolution = input_resolution\n        self.dim = dim\n        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)\n        self.norm = norm_layer(2 * dim)\n\n    def forward(self, x):\n        \"\"\"\n        x: B, H*W, C\n        \"\"\"\n        H, W = self.input_resolution\n        B, L, C = x.shape\n        assert L == H * W, \"input feature has wrong size\"\n        assert H % 2 == 0 and W % 2 == 0, f\"x size ({H}*{W}) are not even.\"\n\n        x = x.view(B, H, W, C)\n\n        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C\n        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C\n        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C\n        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C\n        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C\n        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C\n\n        x = self.reduction(x)\n        x = self.norm(x)\n\n        return x\n\n    def extra_repr(self) -> str:\n        return f\"input_resolution={self.input_resolution}, dim={self.dim}\"\n\n    def flops(self):\n        H, W = self.input_resolution\n        flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim\n        flops += H * W * self.dim // 2\n        return flops\n\n\nclass BasicLayer(nn.Module):\n    \"\"\" A basic Swin Transformer layer for one stage.\n\n    Args:\n        dim (int): Number of input channels.\n        input_resolution (tuple[int]): Input resolution.\n        depth (int): Number of blocks.\n        num_heads (int): Number of attention heads.\n        window_size (int): Local window size.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n        drop (float, optional): Dropout rate. Default: 0.0\n        attn_drop (float, optional): Attention dropout rate. 
Default: 0.0\n        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0\n        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm\n        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None\n        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.\n        pretrained_window_size (int): Local window size in pre-training.\n    \"\"\"\n\n    def __init__(self, dim, input_resolution, depth, num_heads, window_size,\n                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,\n                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,\n                 pretrained_window_size=0):\n\n        super().__init__()\n        self.dim = dim\n        self.input_resolution = input_resolution\n        self.depth = depth\n        self.use_checkpoint = use_checkpoint\n\n        # build blocks\n        self.blocks = nn.ModuleList([\n            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,\n                                 num_heads=num_heads, window_size=window_size,\n                                 shift_size=0 if (i % 2 == 0) else window_size // 2,\n                                 mlp_ratio=mlp_ratio,\n                                 qkv_bias=qkv_bias,\n                                 drop=drop, attn_drop=attn_drop,\n                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,\n                                 norm_layer=norm_layer,\n                                 pretrained_window_size=pretrained_window_size)\n            for i in range(depth)])\n\n        # patch merging layer\n        if downsample is not None:\n            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)\n        else:\n            self.downsample = None\n\n    def forward(self, x):\n        for blk in self.blocks:\n            if self.use_checkpoint:\n                x = checkpoint.checkpoint(blk, x)\n            else:\n                x = blk(x)\n        if self.downsample is not None:\n            x = self.downsample(x)\n        return x\n\n    def extra_repr(self) -> str:\n        return f\"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}\"\n\n    def flops(self):\n        flops = 0\n        for blk in self.blocks:\n            flops += blk.flops()\n        if self.downsample is not None:\n            flops += self.downsample.flops()\n        return flops\n\n    def _init_respostnorm(self):\n        for blk in self.blocks:\n            nn.init.constant_(blk.norm1.bias, 0)\n            nn.init.constant_(blk.norm1.weight, 0)\n            nn.init.constant_(blk.norm2.bias, 0)\n            nn.init.constant_(blk.norm2.weight, 0)\n\n\nclass PatchEmbed(nn.Module):\n    r\"\"\" Image to Patch Embedding\n\n    Args:\n        img_size (int): Image size.  Default: 224.\n        patch_size (int): Patch token size. Default: 4.\n        in_chans (int): Number of input image channels. Default: 3.\n        embed_dim (int): Number of linear projection output channels. Default: 96.\n        norm_layer (nn.Module, optional): Normalization layer. 
Default: None\n    \"\"\"\n\n    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.patches_resolution = patches_resolution\n        self.num_patches = patches_resolution[0] * patches_resolution[1]\n\n        self.in_chans = in_chans\n        self.embed_dim = embed_dim\n\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n        if norm_layer is not None:\n            self.norm = norm_layer(embed_dim)\n        else:\n            self.norm = None\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n        # FIXME look at relaxing size constraints\n        assert H == self.img_size[0] and W == self.img_size[1], \\\n            f\"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).\"\n        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C\n        if self.norm is not None:\n            x = self.norm(x)\n        return x\n\n    def flops(self):\n        Ho, Wo = self.patches_resolution\n        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])\n        if self.norm is not None:\n            flops += Ho * Wo * self.embed_dim\n        return flops\n\n\n@BACKBONES.register_module()\nclass SwinTransformerV2(nn.Module):\n    r\"\"\" Swin Transformer V2\n        A PyTorch impl of : `Swin Transformer V2: Scaling Up Capacity and Resolution`  -\n          https://arxiv.org/abs/2111.09883\n\n    Args:\n        img_size (int | tuple(int)): Input image size. Default 224\n        patch_size (int | tuple(int)): Patch size. Default: 4\n        in_chans (int): Number of input image channels. Default: 3\n        num_classes (int): Number of classes for classification head. Default: 1000\n        embed_dim (int): Patch embedding dimension. Default: 96\n        depths (tuple(int)): Depth of each Swin Transformer layer.\n        num_heads (tuple(int)): Number of attention heads in different layers.\n        window_size (int): Window size. Default: 7\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4\n        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True\n        drop_rate (float): Dropout rate. Default: 0\n        attn_drop_rate (float): Attention dropout rate. Default: 0\n        drop_path_rate (float): Stochastic depth rate. Default: 0.1\n        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.\n        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False\n        patch_norm (bool): If True, add normalization after patch embedding. Default: True\n        use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False\n        pretrained_window_sizes (tuple(int)): Pretrained window sizes of each layer.\n    \"\"\"\n\n    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,\n                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],\n                 window_size=7, mlp_ratio=4., qkv_bias=True,\n                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,\n                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,\n                 use_checkpoint=False, pretrained_window_sizes=[0, 0, 0, 0],\n                 multi_scale=False, upsample='deconv', **kwargs):\n        super().__init__()\n\n        self.num_classes = num_classes\n        self.num_layers = len(depths)\n        self.embed_dim = embed_dim\n        self.ape = ape\n        self.patch_norm = patch_norm\n        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))\n        self.mlp_ratio = mlp_ratio\n\n        # split image into non-overlapping patches\n        self.patch_embed = PatchEmbed(\n            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,\n            norm_layer=norm_layer if self.patch_norm else None)\n        num_patches = self.patch_embed.num_patches\n        patches_resolution = self.patch_embed.patches_resolution\n        self.patches_resolution = patches_resolution\n\n        # absolute position embedding\n        if self.ape:\n            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))\n            trunc_normal_(self.absolute_pos_embed, std=.02)\n\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n        # stochastic depth\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule\n\n        # build layers\n        self.layers = nn.ModuleList()\n        for i_layer in range(self.num_layers):\n            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),\n                               input_resolution=(patches_resolution[0] // (2 ** i_layer),\n                                                 patches_resolution[1] // (2 ** i_layer)),\n                               depth=depths[i_layer],\n                               num_heads=num_heads[i_layer],\n                               window_size=window_size,\n                               mlp_ratio=self.mlp_ratio,\n                               qkv_bias=qkv_bias,\n                               drop=drop_rate, attn_drop=attn_drop_rate,\n                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],\n                               norm_layer=norm_layer,\n                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,\n                               use_checkpoint=use_checkpoint,\n                               pretrained_window_size=pretrained_window_sizes[i_layer])\n            self.layers.append(layer)\n\n        self.norm = norm_layer(self.num_features)\n        self.avgpool = nn.AdaptiveAvgPool1d(1)\n        self.multi_scale = multi_scale\n        if self.multi_scale:\n            self.scales = [1, 2, 4, 4]\n            self.upsample = nn.ModuleList()\n            features = [int(embed_dim * 2 ** i) for i in range(1, self.num_layers)] + [self.num_features]\n            self.multi_scale_fuse = nn.Conv2d(sum(features), self.num_features, 1)\n            for i in range(self.num_layers):\n                self.upsample.append(nn.Upsample(scale_factor=self.scales[i]))\n        else:\n            if upsample 
== 'deconv':\n                self.upsample = nn.ConvTranspose2d(self.num_features, self.num_features, 2, stride=2)\n            elif upsample == 'new_deconv':\n                self.upsample = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),\n                                              nn.Conv2d(self.num_features, self.num_features, 3, stride=1, padding=1),\n                                              nn.BatchNorm2d(self.num_features),\n                                              nn.ReLU(inplace=True)\n                                              )\n            elif upsample == 'new_deconv2':\n                self.upsample = nn.Sequential(nn.Upsample(scale_factor=2),\n                                              nn.Conv2d(self.num_features, self.num_features, 3, stride=1, padding=1),\n                                              nn.BatchNorm2d(self.num_features),\n                                              nn.ReLU(inplace=True)\n                                              )\n            elif upsample == 'bilinear':\n                self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)\n            else:\n                self.upsample = nn.Identity()\n        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()\n\n        self.apply(self._init_weights)\n        for bly in self.layers:\n            bly._init_respostnorm()\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {'absolute_pos_embed'}\n\n    @torch.jit.ignore\n    def no_weight_decay_keywords(self):\n        return {\"cpb_mlp\", \"logit_scale\", 'relative_position_bias_table'}\n\n    def forward_features(self, x):\n        B, C, H, W = x.shape\n        x = self.patch_embed(x)\n        if self.ape:\n            x = x + self.absolute_pos_embed\n        x = self.pos_drop(x)\n\n        if self.multi_scale:\n            # x_2d = x.view(B, H // 4, W // 4, -1).permute(0, 3, 1, 2)  # B C H W\n            # features = [self.upsample[0](x_2d)]\n            features = []\n            for i, layer in enumerate(self.layers):\n                x = layer(x)\n                x_2d = x.view(B, H // (8 * self.scales[i]), W // (8 * self.scales[i]), -1).permute(0, 3, 1,\n                                                                                                   2)  # B C H W\n                features.append(self.upsample[i](x_2d))\n            x = torch.cat(features, dim=1)\n            x = self.multi_scale_fuse(x)\n            x = x.view(B, self.num_features, -1).permute(0, 2, 1)\n            x = self.norm(x)  # B L C\n            x = x.view(B, H // 8, W // 8, self.num_features).permute(0, 3, 1, 2)  # B C H W\n\n        else:\n            for layer in self.layers:\n                x = layer(x)\n            x = self.norm(x)  # B L C\n            x = x.view(B, H // 32, W // 32, self.num_features).permute(0, 3, 1, 2)  # B C H W\n            x = self.upsample(x)\n\n        return x\n\n    def forward(self, x):\n        x = self.forward_features(x)\n        x = self.head(x)\n        return x\n\n    def flops(self):\n        flops = 0\n        flops += 
self.patch_embed.flops()\n        for i, layer in enumerate(self.layers):\n            flops += layer.flops()\n        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)\n        flops += self.num_features * self.num_classes\n        return flops\n"
  },
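  {
    "path": "examples/swin_v2_feature_shape_sketch.py",
    "content": "# Hypothetical example file: a shape sketch for the SwinTransformerV2 backbone\n# above, assuming the mmpose/timm environment that the module itself imports.\n# forward_features reduces a 256x256 input by a factor of 32 and then applies\n# the selected 2x upsample head, so the returned map is\n# (B, num_features, 16, 16) with num_features = embed_dim * 2 ** 3 = 768.\nimport torch\n\nfrom models.models.backbones.swin_transformer_v2 import SwinTransformerV2\n\nmodel = SwinTransformerV2(img_size=256, patch_size=4, window_size=8,\n                          embed_dim=96, num_classes=0, upsample='bilinear')\nmodel.eval()\nwith torch.no_grad():\n    feats = model.forward_features(torch.randn(1, 3, 256, 256))\nprint(feats.shape)  # expected: torch.Size([1, 768, 16, 16])\n"
  },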
  {
    "path": "models/models/backbones/swin_utils.py",
    "content": "# --------------------------------------------------------\n# SimMIM\n# Copyright (c) 2021 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Ze Liu\n# Modified by Zhenda Xie\n# --------------------------------------------------------\n\nimport numpy as np\nimport torch\nfrom scipy import interpolate\n\n\ndef load_pretrained(config, model, logger):\n    checkpoint = torch.load(config, map_location='cpu')\n    checkpoint_model = checkpoint['model']\n\n    if any([True if 'encoder.' in k else False for k in checkpoint_model.keys()]):\n        checkpoint_model = {k.replace('encoder.', ''): v for k, v in checkpoint_model.items() if\n                            k.startswith('encoder.')}\n        print('Detect pre-trained model, remove [encoder.] prefix.')\n    else:\n        print('Detect non-pre-trained model, pass without doing anything.')\n\n    checkpoint = remap_pretrained_keys_swin(model, checkpoint_model, logger)\n    msg = model.load_state_dict(checkpoint_model, strict=False)\n    print(msg)\n\n    del checkpoint\n    torch.cuda.empty_cache()\n\n\ndef remap_pretrained_keys_swin(model, checkpoint_model, logger):\n    state_dict = model.state_dict()\n\n    # Geometric interpolation when pre-trained patch size mismatch with fine-tuned patch size\n    all_keys = list(checkpoint_model.keys())\n    for key in all_keys:\n        if \"relative_position_bias_table\" in key:\n            relative_position_bias_table_pretrained = checkpoint_model[key]\n            relative_position_bias_table_current = state_dict[key]\n            L1, nH1 = relative_position_bias_table_pretrained.size()\n            L2, nH2 = relative_position_bias_table_current.size()\n            if nH1 != nH2:\n                print(f\"Error in loading {key}, passing......\")\n            else:\n                if L1 != L2:\n                    print(f\"{key}: Interpolate relative_position_bias_table using geo.\")\n                    src_size = int(L1 ** 0.5)\n                    dst_size = int(L2 ** 0.5)\n\n                    def geometric_progression(a, r, n):\n                        return a * (1.0 - r ** n) / (1.0 - r)\n\n                    left, right = 1.01, 1.5\n                    while right - left > 1e-6:\n                        q = (left + right) / 2.0\n                        gp = geometric_progression(1, q, src_size // 2)\n                        if gp > dst_size // 2:\n                            right = q\n                        else:\n                            left = q\n\n                    # if q > 1.090307:\n                    #     q = 1.090307\n\n                    dis = []\n                    cur = 1\n                    for i in range(src_size // 2):\n                        dis.append(cur)\n                        cur += q ** (i + 1)\n\n                    r_ids = [-_ for _ in reversed(dis)]\n\n                    x = r_ids + [0] + dis\n                    y = r_ids + [0] + dis\n\n                    t = dst_size // 2.0\n                    dx = np.arange(-t, t + 0.1, 1.0)\n                    dy = np.arange(-t, t + 0.1, 1.0)\n\n                    print(\"Original positions = %s\" % str(x))\n                    print(\"Target positions = %s\" % str(dx))\n\n                    all_rel_pos_bias = []\n\n                    for i in range(nH1):\n                        z = relative_position_bias_table_pretrained[:, i].view(src_size, src_size).float().numpy()\n                        f_cubic = interpolate.interp2d(x, y, z, kind='cubic')\n            
            all_rel_pos_bias.append(torch.Tensor(f_cubic(dx, dy)).contiguous().view(-1, 1).to(\n                            relative_position_bias_table_pretrained.device))\n\n                    new_rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)\n                    checkpoint_model[key] = new_rel_pos_bias\n\n    # delete relative_position_index since we always re-init it\n    relative_position_index_keys = [k for k in checkpoint_model.keys() if \"relative_position_index\" in k]\n    for k in relative_position_index_keys:\n        del checkpoint_model[k]\n\n    # delete relative_coords_table since we always re-init it\n    relative_coords_table_keys = [k for k in checkpoint_model.keys() if \"relative_coords_table\" in k]\n    for k in relative_coords_table_keys:\n        del checkpoint_model[k]\n\n    # re-map keys due to name change\n    rpe_mlp_keys = [k for k in checkpoint_model.keys() if \"rpe_mlp\" in k]\n    for k in rpe_mlp_keys:\n        checkpoint_model[k.replace('rpe_mlp', 'cpb_mlp')] = checkpoint_model.pop(k)\n\n    # delete attn_mask since we always re-init it\n    attn_mask_keys = [k for k in checkpoint_model.keys() if \"attn_mask\" in k]\n    for k in attn_mask_keys:\n        del checkpoint_model[k]\n\n    return checkpoint_model\n"
  },
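  {
    "path": "examples/geometric_interpolation_sketch.py",
    "content": "# Hypothetical example file: a standalone sketch of the bisection used by\n# remap_pretrained_keys_swin above. When pretrained and fine-tuned window\n# sizes differ, the relative-position bias table is resampled on a geometric\n# grid: solve for the ratio q whose src_size // 2 cumulative steps span\n# dst_size // 2, giving positions that are dense near the window centre and\n# sparser toward the edges. The sizes below are illustrative only.\ndef geometric_progression(a, r, n):\n    return a * (1.0 - r ** n) / (1.0 - r)\n\nsrc_size, dst_size = 13, 23  # e.g. 7x7 -> 12x12 windows (tables have 2*W - 1 rows)\nleft, right = 1.01, 1.5\nwhile right - left > 1e-6:\n    q = (left + right) / 2.0\n    if geometric_progression(1, q, src_size // 2) > dst_size // 2:\n        right = q\n    else:\n        left = q\n\ndis, cur = [], 1\nfor i in range(src_size // 2):\n    dis.append(cur)\n    cur += q ** (i + 1)\nprint([round(d, 2) for d in dis])  # the last position lands near dst_size // 2\n"
  },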
  {
    "path": "models/models/detectors/__init__.py",
    "content": "from .pam import PoseAnythingModel\n\n__all__ = ['PoseAnythingModel']\n"
  },
  {
    "path": "models/models/detectors/pam.py",
    "content": "import numpy as np\nimport torch\nfrom mmpose.models import builder\nfrom mmpose.models.builder import POSENETS\nfrom mmpose.models.detectors.base import BasePose\n\nfrom models.models.backbones.swin_utils import load_pretrained\n\n\n@POSENETS.register_module()\nclass PoseAnythingModel(BasePose):\n    \"\"\"Few-shot keypoint detectors.\n    Args:\n        keypoint_head (dict): Keypoint head to process feature.\n        encoder_config (dict): Config for encoder. Default: None.\n        pretrained (str): Path to the pretrained models.\n        train_cfg (dict): Config for training. Default: None.\n        test_cfg (dict): Config for testing. Default: None.\n    \"\"\"\n\n    def __init__(self,\n                 keypoint_head,\n                 encoder_config,\n                 pretrained=False,\n                 train_cfg=None,\n                 test_cfg=None):\n        super().__init__()\n        self.backbone, self.backbone_type = self.init_backbone(pretrained, encoder_config)\n        self.keypoint_head = builder.build_head(keypoint_head)\n        self.keypoint_head.init_weights()\n        self.train_cfg = train_cfg\n        self.test_cfg = test_cfg\n        self.target_type = test_cfg.get('target_type',\n                                        'GaussianHeatMap')  # GaussianHeatMap\n\n    def init_backbone(self, pretrained, encoder_config):\n        if 'swin' in pretrained:\n            encoder_sample = builder.build_backbone(encoder_config)\n            if '.pth' in pretrained:\n                load_pretrained(pretrained, encoder_sample, logger=None)\n            backbone = 'swin'\n        elif 'dino' in pretrained:\n            if 'dinov2' in pretrained:\n                repo = 'facebookresearch/dinov2'\n                backbone = 'dinov2'\n            else:\n                repo = 'facebookresearch/dino:main'\n                backbone = 'dino'\n            encoder_sample = torch.hub.load(repo, pretrained)\n        elif 'resnet' in pretrained:\n            pretrained = 'torchvision://resnet50'\n            encoder_config = dict(type='ResNet', depth=50, out_indices=(3,))\n            encoder_sample = builder.build_backbone(encoder_config)\n            encoder_sample.init_weights(pretrained)\n            backbone = 'resnet50'\n        else:\n            raise NotImplementedError(f'backbone {pretrained} not supported')\n        return encoder_sample, backbone\n\n    @property\n    def with_keypoint(self):\n        \"\"\"Check if has keypoint_head.\"\"\"\n        return hasattr(self, 'keypoint_head')\n\n    def init_weights(self, pretrained=None):\n        \"\"\"Weight initialization for model.\"\"\"\n        self.backbone.init_weights(pretrained)\n        self.encoder_query.init_weights(pretrained)\n        self.keypoint_head.init_weights()\n\n    def forward(self,\n                img_s,\n                img_q,\n                target_s=None,\n                target_weight_s=None,\n                target_q=None,\n                target_weight_q=None,\n                img_metas=None,\n                return_loss=True,\n                **kwargs):\n        \"\"\"Defines the computation performed at every call.\"\"\"\n\n        if return_loss:\n            return self.forward_train(img_s, target_s, target_weight_s, img_q,\n                                      target_q, target_weight_q, img_metas,\n                                      **kwargs)\n        else:\n            return self.forward_test(img_s, target_s, target_weight_s, img_q,\n                                     
target_q, target_weight_q, img_metas,\n                                     **kwargs)\n\n    def forward_dummy(self, img_s, target_s, target_weight_s, img_q, target_q,\n                      target_weight_q, img_metas, **kwargs):\n        return self.predict(\n            img_s, target_s, target_weight_s, img_q, img_metas)\n\n    def forward_train(self,\n                      img_s,\n                      target_s,\n                      target_weight_s,\n                      img_q,\n                      target_q,\n                      target_weight_q,\n                      img_metas,\n                      **kwargs):\n\n        \"\"\"Defines the computation performed at every call when training.\"\"\"\n        bs, _, h, w = img_q.shape\n\n        output, initial_proposals, similarity_map, mask_s = self.predict(\n            img_s, target_s, target_weight_s, img_q, img_metas)\n\n        # parse the img meta to get the target keypoints\n        target_keypoints = self.parse_keypoints_from_img_meta(img_metas, output.device, keyword='query')\n        target_sizes = torch.tensor([img_q.shape[-2], img_q.shape[-1]]).unsqueeze(0).repeat(img_q.shape[0], 1, 1)\n\n        # if return loss\n        losses = dict()\n        if self.with_keypoint:\n            keypoint_losses = self.keypoint_head.get_loss(\n                output, initial_proposals, similarity_map, target_keypoints,\n                target_q, target_weight_q * mask_s, target_sizes)\n            losses.update(keypoint_losses)\n            keypoint_accuracy = self.keypoint_head.get_accuracy(output[-1],\n                                                                target_keypoints,\n                                                                target_weight_q * mask_s,\n                                                                target_sizes,\n                                                                height=h)\n            losses.update(keypoint_accuracy)\n\n        return losses\n\n    def forward_test(self,\n                     img_s,\n                     target_s,\n                     target_weight_s,\n                     img_q,\n                     target_q,\n                     target_weight_q,\n                     img_metas=None,\n                     **kwargs):\n\n        \"\"\"Defines the computation performed at every call when testing.\"\"\"\n        batch_size, _, img_height, img_width = img_q.shape\n\n        output, initial_proposals, similarity_map, _ = self.predict(img_s, target_s, target_weight_s, img_q, img_metas)\n        predicted_pose = output[-1].detach().cpu().numpy()  # [bs, num_query, 2]\n\n        result = {}\n        if self.with_keypoint:\n            keypoint_result = self.keypoint_head.decode(img_metas, predicted_pose, img_size=[img_width, img_height])\n            result.update(keypoint_result)\n\n        result.update({\n            \"points\":\n                torch.cat((initial_proposals, output.squeeze(1))).cpu().numpy()\n        })\n        result.update({\"sample_image_file\": img_metas[0]['sample_image_file']})\n\n        return result\n\n    def predict(self,\n                img_s,\n                target_s,\n                target_weight_s,\n                img_q,\n                img_metas=None):\n\n        batch_size, _, img_height, img_width = img_q.shape\n        assert [i['sample_skeleton'][0] != i['query_skeleton'] for i in img_metas]\n        skeleton = [i['sample_skeleton'][0] for i in img_metas]\n\n        feature_q, feature_s = self.extract_features(img_s, 
img_q)\n\n        mask_s = target_weight_s[0]\n        for target_weight in target_weight_s:\n            mask_s = mask_s * target_weight\n\n        output, initial_proposals, similarity_map = self.keypoint_head(feature_q, feature_s, target_s, mask_s, skeleton)\n\n        return output, initial_proposals, similarity_map, mask_s\n\n    def extract_features(self, img_s, img_q):\n        if self.backbone_type == 'swin':\n            feature_q = self.backbone.forward_features(img_q)  # [bs, C, h, w]\n            feature_s = [self.backbone.forward_features(img) for img in img_s]\n        elif self.backbone_type == 'dino':\n            batch_size, _, img_height, img_width = img_q.shape\n            feature_q = self.backbone.get_intermediate_layers(img_q, n=1)[0][:, 1:] \\\n                .reshape(batch_size, img_height // 8, img_width // 8, -1).permute(0, 3, 1, 2)  # [bs, 3, h, w]\n            feature_s = [self.backbone.get_intermediate_layers(img, n=1)[0][:, 1:].\n                         reshape(batch_size, img_height // 8, img_width // 8, -1).permute(0, 3, 1, 2) for img in img_s]\n        elif self.backbone_type == 'dinov2':\n            batch_size, _, img_height, img_width = img_q.shape\n            feature_q = self.backbone.get_intermediate_layers(img_q, n=1, reshape=True)[0]  # [bs, c, h, w]\n            feature_s = [self.backbone.get_intermediate_layers(img, n=1, reshape=True)[0] for img in img_s]\n        else:\n            feature_s = [self.backbone(img) for img in img_s]\n            feature_q = self.encoder_query(img_q)\n\n        return feature_q, feature_s\n\n    def parse_keypoints_from_img_meta(self, img_meta, device, keyword='query'):\n        \"\"\"Parse keypoints from the img_meta.\n\n        Args:\n            img_meta (dict): Image meta info.\n            device (torch.device): Device of the output keypoints.\n            keyword (str): 'query' or 'sample'. 
Default: 'query'.\n\n        Returns:\n            Tensor: Keypoints coordinates of query images.\n        \"\"\"\n\n        if keyword == 'query':\n            query_kpt = torch.stack([\n                torch.tensor(info[f'{keyword}_joints_3d']).to(device)\n                for info in img_meta\n            ], dim=0)[:, :, :2]  # [bs, num_query, 2]\n        else:\n            query_kpt = []\n            for info in img_meta:\n                if isinstance(info[f'{keyword}_joints_3d'][0], torch.Tensor):\n                    samples = torch.stack(info[f'{keyword}_joints_3d'])\n                else:\n                    samples = np.array(info[f'{keyword}_joints_3d'])\n                query_kpt.append(torch.tensor(samples).to(device)[:, :, :2])\n            query_kpt = torch.stack(query_kpt, dim=0)  # [bs, , num_samples, num_query, 2]\n        return query_kpt\n\n\n    # UNMODIFIED\n    def show_result(self,\n                    img,\n                    result,\n                    skeleton=None,\n                    kpt_score_thr=0.3,\n                    bbox_color='green',\n                    pose_kpt_color=None,\n                    pose_limb_color=None,\n                    radius=4,\n                    text_color=(255, 0, 0),\n                    thickness=1,\n                    font_scale=0.5,\n                    win_name='',\n                    show=False,\n                    wait_time=0,\n                    out_file=None):\n        \"\"\"Draw `result` over `img`.\n\n        Args:\n            img (str or Tensor): The image to be displayed.\n            result (list[dict]): The results to draw over `img`\n                (bbox_result, pose_result).\n            kpt_score_thr (float, optional): Minimum score of keypoints\n                to be shown. 
Default: 0.3.\n            bbox_color (str or tuple or :obj:`Color`): Color of bbox lines.\n            pose_kpt_color (np.array[Nx3]`): Color of N keypoints.\n                If None, do not draw keypoints.\n            pose_limb_color (np.array[Mx3]): Color of M limbs.\n                If None, do not draw limbs.\n            text_color (str or tuple or :obj:`Color`): Color of texts.\n            thickness (int): Thickness of lines.\n            font_scale (float): Font scales of texts.\n            win_name (str): The window name.\n            wait_time (int): Value of waitKey param.\n                Default: 0.\n            out_file (str or None): The filename to write the image.\n                Default: None.\n\n        Returns:\n            Tensor: Visualized img, only if not `show` or `out_file`.\n        \"\"\"\n\n        img = mmcv.imread(img)\n        img = img.copy()\n        img_h, img_w, _ = img.shape\n\n        bbox_result = []\n        pose_result = []\n        for res in result:\n            bbox_result.append(res['bbox'])\n            pose_result.append(res['keypoints'])\n\n        if len(bbox_result) > 0:\n            bboxes = np.vstack(bbox_result)\n            # draw bounding boxes\n            mmcv.imshow_bboxes(\n                img,\n                bboxes,\n                colors=bbox_color,\n                top_k=-1,\n                thickness=thickness,\n                show=False,\n                win_name=win_name,\n                wait_time=wait_time,\n                out_file=None)\n\n            for person_id, kpts in enumerate(pose_result):\n                # draw each point on image\n                if pose_kpt_color is not None:\n                    assert len(pose_kpt_color) == len(kpts), (\n                        len(pose_kpt_color), len(kpts))\n                    for kid, kpt in enumerate(kpts):\n                        x_coord, y_coord, kpt_score = int(kpt[0]), int(\n                            kpt[1]), kpt[2]\n                        if kpt_score > kpt_score_thr:\n                            img_copy = img.copy()\n                            r, g, b = pose_kpt_color[kid]\n                            cv2.circle(img_copy, (int(x_coord), int(y_coord)),\n                                       radius, (int(r), int(g), int(b)), -1)\n                            transparency = max(0, min(1, kpt_score))\n                            cv2.addWeighted(\n                                img_copy,\n                                transparency,\n                                img,\n                                1 - transparency,\n                                0,\n                                dst=img)\n\n                # draw limbs\n                if skeleton is not None and pose_limb_color is not None:\n                    assert len(pose_limb_color) == len(skeleton)\n                    for sk_id, sk in enumerate(skeleton):\n                        pos1 = (int(kpts[sk[0] - 1, 0]), int(kpts[sk[0] - 1,\n                        1]))\n                        pos2 = (int(kpts[sk[1] - 1, 0]), int(kpts[sk[1] - 1,\n                        1]))\n                        if (pos1[0] > 0 and pos1[0] < img_w and pos1[1] > 0\n                                and pos1[1] < img_h and pos2[0] > 0\n                                and pos2[0] < img_w and pos2[1] > 0\n                                and pos2[1] < img_h\n                                and kpts[sk[0] - 1, 2] > kpt_score_thr\n                                and kpts[sk[1] - 1, 2] > kpt_score_thr):\n            
                img_copy = img.copy()\n                            X = (pos1[0], pos2[0])\n                            Y = (pos1[1], pos2[1])\n                            mX = np.mean(X)\n                            mY = np.mean(Y)\n                            length = ((Y[0] - Y[1]) ** 2 + (X[0] - X[1]) ** 2) ** 0.5\n                            angle = math.degrees(\n                                math.atan2(Y[0] - Y[1], X[0] - X[1]))\n                            stickwidth = 2\n                            polygon = cv2.ellipse2Poly(\n                                (int(mX), int(mY)),\n                                (int(length / 2), int(stickwidth)), int(angle),\n                                0, 360, 1)\n\n                            r, g, b = pose_limb_color[sk_id]\n                            cv2.fillConvexPoly(img_copy, polygon,\n                                               (int(r), int(g), int(b)))\n                            transparency = max(\n                                0,\n                                min(\n                                    1, 0.5 *\n                                       (kpts[sk[0] - 1, 2] + kpts[sk[1] - 1, 2])))\n                            cv2.addWeighted(\n                                img_copy,\n                                transparency,\n                                img,\n                                1 - transparency,\n                                0,\n                                dst=img)\n\n        show, wait_time = 1, 1\n        if show:\n            height, width = img.shape[:2]\n            max_ = max(height, width)\n\n            factor = min(1, 800 / max_)\n            enlarge = cv2.resize(\n                img, (0, 0),\n                fx=factor,\n                fy=factor,\n                interpolation=cv2.INTER_CUBIC)\n            imshow(enlarge, win_name, wait_time)\n\n        if out_file is not None:\n            imwrite(img, out_file)\n\n        return img"
  },
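  {
    "path": "examples/support_mask_sketch.py",
    "content": "# Hypothetical example file: a minimal sketch of the support-mask reduction in\n# PoseAnythingModel.predict above. Per-shot keypoint visibility weights are\n# multiplied together, so a keypoint is kept only if it is visible in every\n# support image. Weights follow the [num_keypoints, 1] convention; the loop\n# mirrors predict(), where re-multiplying shot 1 is idempotent for 0/1 weights.\nimport torch\n\ntarget_weight_s = [\n    torch.tensor([[1.], [1.], [0.]]),  # shot 1: keypoint 3 not visible\n    torch.tensor([[1.], [0.], [1.]]),  # shot 2: keypoint 2 not visible\n]\nmask_s = target_weight_s[0]\nfor target_weight in target_weight_s:\n    mask_s = mask_s * target_weight\nprint(mask_s.squeeze(-1))  # tensor([1., 0., 0.]) -> only keypoint 1 survives\n"
  },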
  {
    "path": "models/models/keypoint_heads/__init__.py",
    "content": "from .head import PoseHead\n\n__all__ = ['PoseHead']\n"
  },
  {
    "path": "models/models/keypoint_heads/head.py",
    "content": "from copy import deepcopy\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom mmcv.cnn import (Conv2d, Linear, xavier_init)\nfrom mmcv.cnn.bricks.transformer import build_positional_encoding\nfrom mmpose.core.evaluation import keypoint_pck_accuracy\nfrom mmpose.core.post_processing import transform_preds\nfrom mmpose.models import HEADS\nfrom mmpose.models.utils.ops import resize\n\nfrom models.models.utils import build_transformer\n\n\ndef inverse_sigmoid(x, eps=1e-3):\n    x = x.clamp(min=0, max=1)\n    x1 = x.clamp(min=eps)\n    x2 = (1 - x).clamp(min=eps)\n    return torch.log(x1 / x2)\n\n\nclass TokenDecodeMLP(nn.Module):\n    '''\n    The MLP used to predict coordinates from the support keypoints tokens.\n    '''\n\n    def __init__(self,\n                 in_channels,\n                 hidden_channels,\n                 out_channels=2,\n                 num_layers=3):\n        super(TokenDecodeMLP, self).__init__()\n        layers = []\n        for i in range(num_layers):\n            if i == 0:\n                layers.append(nn.Linear(in_channels, hidden_channels))\n                layers.append(nn.GELU())\n            else:\n                layers.append(nn.Linear(hidden_channels, hidden_channels))\n                layers.append(nn.GELU())\n        layers.append(nn.Linear(hidden_channels, out_channels))\n        self.mlp = nn.Sequential(*layers)\n\n    def forward(self, x):\n        return self.mlp(x)\n\n\n@HEADS.register_module()\nclass PoseHead(nn.Module):\n    '''\n    In two stage regression A3, the proposal generator are moved into transformer.\n    All valid proposals will be added with an positional embedding to better regress the location\n    '''\n\n    def __init__(self,\n                 in_channels,\n                 transformer=None,\n                 positional_encoding=dict(\n                     type='SinePositionalEncoding',\n                     num_feats=128,\n                     normalize=True),\n                 encoder_positional_encoding=dict(\n                     type='SinePositionalEncoding',\n                     num_feats=512,\n                     normalize=True),\n                 share_kpt_branch=False,\n                 num_decoder_layer=3,\n                 with_heatmap_loss=False,\n                 with_bb_loss=False,\n                 bb_temperature=0.2,\n                 heatmap_loss_weight=2.0,\n                 support_order_dropout=-1,\n                 extra=None,\n                 train_cfg=None,\n                 test_cfg=None):\n        super().__init__()\n\n        self.in_channels = in_channels\n        self.positional_encoding = build_positional_encoding(positional_encoding)\n        self.encoder_positional_encoding = build_positional_encoding(encoder_positional_encoding)\n        self.transformer = build_transformer(transformer)\n        self.embed_dims = self.transformer.d_model\n        self.with_heatmap_loss = with_heatmap_loss\n        self.with_bb_loss = with_bb_loss\n        self.bb_temperature = bb_temperature\n        self.heatmap_loss_weight = heatmap_loss_weight\n        self.support_order_dropout = support_order_dropout\n\n        assert 'num_feats' in positional_encoding\n        num_feats = positional_encoding['num_feats']\n        assert num_feats * 2 == self.embed_dims, 'embed_dims should' \\\n                                                 f' be exactly 2 times of num_feats. 
 Found {self.embed_dims}' \\\n                                                 f' and {num_feats}.'\n        if extra is not None and not isinstance(extra, dict):\n            raise TypeError('extra should be dict or None.')\n        # Initialize layers of the transformer head.\n        self.input_proj = Conv2d(self.in_channels, self.embed_dims, kernel_size=1)\n        self.query_proj = Linear(self.in_channels, self.embed_dims)\n        # Instantiate the proposal generator and subsequent keypoint branch.\n        kpt_branch = TokenDecodeMLP(\n            in_channels=self.embed_dims, hidden_channels=self.embed_dims)\n        if share_kpt_branch:\n            self.kpt_branch = nn.ModuleList(\n                [kpt_branch for i in range(num_decoder_layer)])\n        else:\n            self.kpt_branch = nn.ModuleList(\n                [deepcopy(kpt_branch) for i in range(num_decoder_layer)])\n\n        self.train_cfg = {} if train_cfg is None else train_cfg\n        self.test_cfg = {} if test_cfg is None else test_cfg\n        self.target_type = self.test_cfg.get('target_type', 'GaussianHeatMap')\n\n    def init_weights(self):\n        \"\"\"Initialize weights of the transformer head.\"\"\"\n        for m in self.modules():\n            if hasattr(m, 'weight') and m.weight.dim() > 1:\n                xavier_init(m, distribution='uniform')\n        # The initialization for transformer is important\n        self.transformer.init_weights()\n        # initialization for input_proj & prediction head\n        for mlp in self.kpt_branch:\n            nn.init.constant_(mlp.mlp[-1].weight.data, 0)\n            nn.init.constant_(mlp.mlp[-1].bias.data, 0)\n        nn.init.xavier_uniform_(self.input_proj.weight, gain=1)\n        nn.init.constant_(self.input_proj.bias, 0)\n\n        nn.init.xavier_uniform_(self.query_proj.weight, gain=1)\n        nn.init.constant_(self.query_proj.bias, 0)\n\n    def forward(self, x, feature_s, target_s, mask_s, skeleton):\n        \"\"\"Forward function for a single feature level.\n\n        Args:\n            x (Tensor): Input feature from backbone's single stage, shape\n                [bs, c, h, w].\n\n        Returns:\n            output_kpts (Tensor): Keypoint coordinates predicted by each\n                decoder layer, normalized to [0, 1], with shape\n                [num_dec_layer, bs, num_query, 2].\n            initial_proposals (Tensor): Initial keypoint proposals from the\n                proposal generator, normalized to [0, 1], with shape\n                [bs, num_query, 2].\n            similarity_map (Tensor): Similarity maps between support keypoint\n                features and query image features, with shape\n                [bs, num_query, h, w].\n        \"\"\"\n
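        # Sketch of the pipeline below: support-image features are pooled at the support GT heatmaps\n        # into one token per keypoint; the transformer refines these tokens against the query feature\n        # map, and each decoder layer regresses offsets applied in inverse-sigmoid space (see below).\n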
        # construct binary masks which are used for the transformer.\n        # NOTE: following the official DETR repo, non-zero values represent\n        # ignored positions, while zero values mean valid positions.\n\n        # process query image feature\n        x = self.input_proj(x)\n        bs, dim, h, w = x.shape\n\n        # Disable the support keypoint positional embedding\n        support_order_embedding = x.new_zeros((bs, self.embed_dims, 1, target_s[0].shape[1])).to(torch.bool)\n\n        # Feature map pos embedding\n        masks = x.new_zeros((x.shape[0], x.shape[2], x.shape[3])).to(torch.bool)\n        pos_embed = self.positional_encoding(masks)\n\n        # process keypoint token feature\n        query_embed_list = []\n        for i, (feature, target) in enumerate(zip(feature_s, target_s)):\n            # resize the support feature back to the heatmap sizes.\n            resized_feature = resize(\n                input=feature,\n                size=target.shape[-2:],\n                mode='bilinear',\n                align_corners=False)\n            target = target / (target.sum(dim=-1).sum(dim=-1)[:, :, None, None] + 1e-8)\n            support_keypoints = target.flatten(2) @ resized_feature.flatten(2).permute(0, 2, 1)\n            query_embed_list.append(support_keypoints)\n\n        support_keypoints = torch.mean(torch.stack(query_embed_list, dim=0), 0)\n        support_keypoints = support_keypoints * mask_s\n        support_keypoints = self.query_proj(support_keypoints)\n        masks_query = (~mask_s.to(torch.bool)).squeeze(-1)  # True indicates the query matches no actual joint.\n\n        # outs_dec: [nb_dec, bs, num_query, c]\n        # memory: [bs, c, h, w]\n        # x = Query image feature, support_keypoints = Support keypoint feature\n        outs_dec, initial_proposals, out_points, similarity_map = self.transformer(x,\n                                                                                   masks,\n                                                                                   support_keypoints,\n                                                                                   pos_embed,\n                                                                                   support_order_embedding,\n                                                                                   masks_query,\n                                                                                   self.positional_encoding,\n                                                                                   self.kpt_branch,\n                                                                                   skeleton)\n\n        output_kpts = []\n        for idx in range(outs_dec.shape[0]):\n            layer_delta_unsig = self.kpt_branch[idx](outs_dec[idx])\n            layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid(\n                out_points[idx])\n            output_kpts.append(layer_outputs_unsig.sigmoid())\n\n        return torch.stack(output_kpts, dim=0), initial_proposals, similarity_map\n\n    def get_loss(self, output, initial_proposals, similarity_map, target, target_heatmap, target_weight, target_sizes):\n        # Calculate top-down keypoint loss.\n        losses = dict()\n
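        # Assumed shapes (cf. forward above): output [num_dec_layer, bs, nq, 2] in [0, 1]; target\n        # [bs, nq, 2] in absolute pixels; target_weight [bs, nq, 1]; target_sizes [bs, 1, 2].\n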
        # normalize the target coordinates to [0, 1] to match the predictions.\n        num_dec_layer, bs, nq = output.shape[:3]\n        target_sizes = target_sizes.to(output.device)  # [bs, 1, 2]\n        target = target / target_sizes\n        target = target[None, :, :, :].repeat(num_dec_layer, 1, 1, 1)\n\n        # count the visible joints per sample; guard against division by zero\n        normalizer = target_weight.squeeze(dim=-1).sum(dim=-1)  # [bs, ]\n        normalizer[normalizer == 0] = 1\n\n        # compute the heatmap loss\n        if self.with_heatmap_loss:\n            losses['heatmap_loss'] = self.heatmap_loss(\n                similarity_map, target_heatmap, target_weight,\n                normalizer) * self.heatmap_loss_weight\n\n        # compute l1 loss for initial_proposals\n        proposal_l1_loss = F.l1_loss(\n            initial_proposals, target[0], reduction=\"none\")\n        proposal_l1_loss = proposal_l1_loss.sum(\n            dim=-1, keepdim=False) * target_weight.squeeze(dim=-1)\n        proposal_l1_loss = proposal_l1_loss.sum(\n            dim=-1, keepdim=False) / normalizer  # [bs, ]\n        losses['proposal_loss'] = proposal_l1_loss.sum() / bs\n\n        # compute l1 loss for each layer\n        for idx in range(num_dec_layer):\n            layer_output, layer_target = output[idx], target[idx]\n            l1_loss = F.l1_loss(\n                layer_output, layer_target, reduction=\"none\")  # [bs, query, 2]\n            l1_loss = l1_loss.sum(\n                dim=-1, keepdim=False) * target_weight.squeeze(\n                dim=-1)  # [bs, query]\n            # normalize the loss for each sample with the number of visible joints\n            l1_loss = l1_loss.sum(dim=-1, keepdim=False) / normalizer  # [bs, ]\n            losses['l1_loss' + '_layer' + str(idx)] = l1_loss.sum() / bs\n\n        return losses\n\n    def get_max_coords(self, heatmap, heatmap_size=64):\n        B, C, H, W = heatmap.shape\n        heatmap = heatmap.view(B, C, -1)\n        max_cor = heatmap.argmax(dim=2)\n        row, col = torch.floor(max_cor / heatmap_size), max_cor % heatmap_size\n        support_joints = torch.cat((row.unsqueeze(-1), col.unsqueeze(-1)), dim=-1)\n        return support_joints\n\n    def heatmap_loss(self, similarity_map, target_heatmap, target_weight,\n                     normalizer):\n        # similarity_map: [bs, num_query, h, w]\n        # target_heatmap: [bs, num_query, sh, sw]\n        # target_weight: [bs, num_query, 1]\n\n        # preprocess the similarity_map\n        h, w = similarity_map.shape[-2:]\n        # similarity_map = torch.clamp(similarity_map, 0.0, None)\n        similarity_map = similarity_map.sigmoid()\n\n        target_heatmap = F.interpolate(\n            target_heatmap, size=(h, w), mode='bilinear')\n        target_heatmap = (target_heatmap /\n                          (target_heatmap.max(dim=-1)[0].max(dim=-1)[0] + 1e-10)[:, :, None,\n                          None])  # normalize so that the peak of each query's heatmap is 1\n\n        l2_loss = F.mse_loss(\n            similarity_map, target_heatmap, reduction=\"none\")  # bs, nq, h, w\n        l2_loss = l2_loss * target_weight[:, :, :, None]  # bs, nq, h, w\n        l2_loss = l2_loss.flatten(2, 3).sum(-1) / (h * w)  # bs, nq\n        l2_loss = l2_loss.sum(-1) / normalizer  # bs,\n\n        return l2_loss.mean()\n\n    def get_accuracy(self, output, target, target_weight, target_sizes, height=256):\n        \"\"\"Calculate accuracy for top-down keypoint loss.\n\n        Args:\n            output (torch.Tensor[NxKx2]): estimated keypoints in ABSOLUTE coordinates.\n            target (torch.Tensor[NxKx2]): gt keypoints in ABSOLUTE coordinates.\n            target_weight (torch.Tensor[NxKx1]): Weights across different joint types.\n            target_sizes (torch.Tensor[Nx2]): shapes of the image.\n        \"\"\"\n        # NOTE: In POMNet, PCK is estimated on 1/8 resolution, which differs slightly from here.\n\n        accuracy = dict()\n        output = output * float(height)\n        output, target, target_weight, target_sizes = (\n            output.detach().cpu().numpy(), target.detach().cpu().numpy(),\n            target_weight.squeeze(-1).long().detach().cpu().numpy(),\n            target_sizes.squeeze(1).detach().cpu().numpy())\n\n        _, avg_acc, _ = keypoint_pck_accuracy(\n            output,\n            target,\n            target_weight.astype(np.bool_),\n            thr=0.2,\n            normalize=target_sizes)\n        accuracy['acc_pose'] = float(avg_acc)\n\n        return accuracy\n\n    def decode(self, img_metas, output, img_size, **kwargs):\n        \"\"\"Decode the predicted keypoints from prediction.\n\n        Args:\n            img_metas (list(dict)): Information about data augmentation\n                By default this includes:\n                - \"image_file\": path to the image file\n                - \"center\": center of the bbox\n                - \"scale\": scale of the bbox\n                - \"rotation\": rotation of the bbox\n                - \"bbox_score\": score of bbox\n            output (np.ndarray[N, K, 2]): predicted keypoint coordinates,\n                normalized to [0, 1].\n        \"\"\"\n        batch_size = len(img_metas)\n        W, H = img_size\n        output = output * np.array([W, H])[None, None, :]  # [bs, query, 2], coordinates recovered to the input image scale.\n\n        if 'bbox_id' in img_metas[0] or 'query_bbox_id' in img_metas[0]:\n            bbox_ids = []\n        else:\n            bbox_ids = None\n\n        c = np.zeros((batch_size, 2), dtype=np.float32)\n        s = np.zeros((batch_size, 2), dtype=np.float32)\n        image_paths = []\n        score = np.ones(batch_size)\n        for i in range(batch_size):\n            c[i, :] = img_metas[i]['query_center']\n            s[i, :] = img_metas[i]['query_scale']\n            image_paths.append(img_metas[i]['query_image_file'])\n\n            if 'query_bbox_score' in img_metas[i]:\n                score[i] = np.array(\n                    img_metas[i]['query_bbox_score']).reshape(-1)\n            if 'bbox_id' in img_metas[i]:\n                bbox_ids.append(img_metas[i]['bbox_id'])\n            elif 'query_bbox_id' in img_metas[i]:\n                bbox_ids.append(img_metas[i]['query_bbox_id'])\n\n        preds = np.zeros(output.shape)\n        for idx in range(output.shape[0]):\n            preds[idx] = transform_preds(\n                output[idx],\n                c[idx],\n                s[idx], [W, H],\n                use_udp=self.test_cfg.get('use_udp', False))\n\n        all_preds = np.zeros((batch_size, preds.shape[1], 3), dtype=np.float32)\n        all_boxes = np.zeros((batch_size, 6), dtype=np.float32)\n        all_preds[:, :, 0:2] = preds[:, :, 0:2]\n        all_preds[:, :, 2:3] = 1.0\n        all_boxes[:, 0:2] = c[:, 0:2]\n        all_boxes[:, 2:4] = s[:, 0:2]\n        all_boxes[:, 4] = np.prod(s * 200.0, axis=1)\n        all_boxes[:, 5] = score\n\n        result = {}\n\n        result['preds'] = all_preds\n        result['boxes'] = all_boxes\n        result['image_paths'] = image_paths\n        result['bbox_ids'] = bbox_ids\n\n        return result\n"
  },
  {
    "path": "models/models/utils/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .builder import build_linear_layer, build_transformer, build_backbone\nfrom .encoder_decoder import EncoderDecoder\nfrom .positional_encoding import (LearnedPositionalEncoding,\n                                  SinePositionalEncoding)\nfrom .transformer import (DetrTransformerDecoderLayer, DetrTransformerDecoder,\n                          DetrTransformerEncoder, DynamicConv)\n\n__all__ = [\n    'build_transformer', 'build_backbone', 'build_linear_layer', 'DetrTransformerDecoderLayer',\n    'DetrTransformerDecoder', 'DetrTransformerEncoder',\n    'LearnedPositionalEncoding', 'SinePositionalEncoding',\n    'EncoderDecoder',\n]\n"
  },
  {
    "path": "models/models/utils/builder.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport torch.nn as nn\nfrom mmcv.utils import Registry, build_from_cfg\n\nTRANSFORMER = Registry('Transformer')\nBACKBONES = Registry('BACKBONES')\nLINEAR_LAYERS = Registry('linear layers')\n\n\ndef build_backbone(cfg, default_args=None):\n    \"\"\"Build backbone.\"\"\"\n    return build_from_cfg(cfg, BACKBONES, default_args)\n\n\ndef build_transformer(cfg, default_args=None):\n    \"\"\"Builder for Transformer.\"\"\"\n    return build_from_cfg(cfg, TRANSFORMER, default_args)\n\n\nLINEAR_LAYERS.register_module('Linear', module=nn.Linear)\n\n\ndef build_linear_layer(cfg, *args, **kwargs):\n    \"\"\"Build linear layer.\n    Args:\n        cfg (None or dict): The linear layer config, which should contain:\n            - type (str): Layer type.\n            - layer args: Args needed to instantiate a linear layer.\n        args (argument list): Arguments passed to the `__init__`\n            method of the corresponding linear layer.\n        kwargs (keyword arguments): Keyword arguments passed to the `__init__`\n            method of the corresponding linear layer.\n    Returns:\n        nn.Module: Created linear layer.\n    \"\"\"\n    if cfg is None:\n        cfg_ = dict(type='Linear')\n    else:\n        if not isinstance(cfg, dict):\n            raise TypeError('cfg must be a dict')\n        if 'type' not in cfg:\n            raise KeyError('the cfg dict must contain the key \"type\"')\n        cfg_ = cfg.copy()\n\n    layer_type = cfg_.pop('type')\n    if layer_type not in LINEAR_LAYERS:\n        raise KeyError(f'Unrecognized linear type {layer_type}')\n    linear_layer = LINEAR_LAYERS.get(layer_type)\n\n    layer = linear_layer(*args, **kwargs, **cfg_)\n\n    return layer\n"
  },
  {
    "path": "models/models/utils/encoder_decoder.py",
    "content": "import copy\nfrom typing import Optional\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom mmcv.cnn import xavier_init\nfrom torch import Tensor\n\nfrom models.models.utils.builder import TRANSFORMER\n\n\ndef inverse_sigmoid(x, eps=1e-3):\n    x = x.clamp(min=0, max=1)\n    x1 = x.clamp(min=eps)\n    x2 = (1 - x).clamp(min=eps)\n    return torch.log(x1 / x2)\n\n\nclass MLP(nn.Module):\n    \"\"\" Very simple multi-layer perceptron (also called FFN)\"\"\"\n\n    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):\n        super().__init__()\n        self.num_layers = num_layers\n        h = [hidden_dim] * (num_layers - 1)\n        self.layers = nn.ModuleList(\n            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))\n\n    def forward(self, x):\n        for i, layer in enumerate(self.layers):\n            x = F.gelu(layer(x)) if i < self.num_layers - 1 else layer(x)\n        return x\n\n\nclass ProposalGenerator(nn.Module):\n\n    def __init__(self, hidden_dim, proj_dim, dynamic_proj_dim):\n        super().__init__()\n        self.support_proj = nn.Linear(hidden_dim, proj_dim)\n        self.query_proj = nn.Linear(hidden_dim, proj_dim)\n        self.dynamic_proj = nn.Sequential(\n            nn.Linear(hidden_dim, dynamic_proj_dim),\n            nn.ReLU(),\n            nn.Linear(dynamic_proj_dim, hidden_dim))\n        self.dynamic_act = nn.Tanh()\n\n    def forward(self, query_feat, support_feat, spatial_shape):\n        \"\"\"\n        Args:\n            support_feat: [query, bs, c]\n            query_feat: [hw, bs, c]\n            spatial_shape: h, w\n        \"\"\"\n        device = query_feat.device\n        _, bs, c = query_feat.shape\n        h, w = spatial_shape\n        side_normalizer = torch.tensor([w, h]).to(query_feat.device)[None, None,\n                          :]  # [1, 1, 2], used to normalize the coords to [0, 1]\n\n        query_feat = query_feat.transpose(0, 1)\n        support_feat = support_feat.transpose(0, 1)\n        nq = support_feat.shape[1]\n\n        fs_proj = self.support_proj(support_feat)  # [bs, query, c]\n        fq_proj = self.query_proj(query_feat)  # [bs, hw, c]\n        pattern_attention = self.dynamic_act(self.dynamic_proj(fs_proj))  # [bs, query, c]\n\n        fs_feat = (pattern_attention + 1) * fs_proj  # [bs, query, c]\n        similarity = torch.bmm(fq_proj, fs_feat.transpose(1, 2))  # [bs, hw, query]\n        similarity = similarity.transpose(1, 2).reshape(bs, nq, h, w)\n        grid_y, grid_x = torch.meshgrid(\n            torch.linspace(0.5, h - 0.5, h, dtype=torch.float32, device=device),  # (h, w)\n            torch.linspace(0.5, w - 0.5, w, dtype=torch.float32, device=device))\n\n        # compute softmax and sum up\n        coord_grid = torch.stack([grid_x, grid_y],\n                                 dim=0).unsqueeze(0).unsqueeze(0).repeat(bs, nq, 1, 1, 1)  # [bs, query, 2, h, w]\n        coord_grid = coord_grid.permute(0, 1, 3, 4, 2)  # [bs, query, h, w, 2]\n        similarity_softmax = similarity.flatten(2, 3).softmax(dim=-1)  # [bs, query, hw]\n        similarity_coord_grid = similarity_softmax[:, :, :, None] * coord_grid.flatten(2, 3)\n        proposal_for_loss = similarity_coord_grid.sum(dim=2, keepdim=False)  # [bs, query, 2]\n        proposal_for_loss = proposal_for_loss / side_normalizer\n\n        max_pos = torch.argmax(similarity.reshape(bs, nq, -1), dim=-1, keepdim=True)  # (bs, nq, 1)\n        max_mask = F.one_hot(max_pos, num_classes=w * h)  # (bs, nq, 1,
 w*h)\n        max_mask = max_mask.reshape(bs, nq, h, w).type(torch.float)  # (bs, nq, h, w), matching the row-major flatten of the (h, w) grid\n        local_max_mask = F.max_pool2d(\n            input=max_mask, kernel_size=3, stride=1,\n            padding=1).reshape(bs, nq, w * h, 1)  # (bs, nq, w*h, 1)\n        '''\n        proposal = (similarity_coord_grid * local_max_mask).sum(\n            dim=2, keepdim=False) / torch.count_nonzero(\n                local_max_mask, dim=2)\n        '''\n        # first, extract the local probability map with the mask\n        local_similarity_softmax = similarity_softmax[:, :, :, None] * local_max_mask  # (bs, nq, w*h, 1)\n\n        # then, re-normalize the local probability map\n        local_similarity_softmax = local_similarity_softmax / (\n                local_similarity_softmax.sum(dim=-2, keepdim=True) + 1e-10\n        )  # [bs, nq, w*h, 1]\n\n        # point-wise multiplication of local probability map and coord grid\n        proposals = local_similarity_softmax * coord_grid.flatten(2, 3)  # [bs, nq, w*h, 2]\n\n        # sum the multiplication to obtain the final coord proposals\n        proposals = proposals.sum(dim=2) / side_normalizer  # [bs, nq, 2]\n\n        return proposal_for_loss, similarity, proposals\n\n\n@TRANSFORMER.register_module()\nclass EncoderDecoder(nn.Module):\n\n    def __init__(self,\n                 d_model=256,\n                 nhead=8,\n                 num_encoder_layers=3,\n                 num_decoder_layers=3,\n                 graph_decoder=None,\n                 dim_feedforward=2048,\n                 dropout=0.1,\n                 activation=\"relu\",\n                 normalize_before=False,\n                 similarity_proj_dim=256,\n                 dynamic_proj_dim=128,\n                 return_intermediate_dec=True,\n                 look_twice=False,\n                 detach_support_feat=False):\n        super().__init__()\n\n        self.d_model = d_model\n        self.nhead = nhead\n\n        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout,\n                                                activation, normalize_before)\n        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None\n        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)\n\n        decoder_layer = GraphTransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout,\n                                                     activation, normalize_before, graph_decoder)\n        decoder_norm = nn.LayerNorm(d_model)\n        self.decoder = GraphTransformerDecoder(d_model, decoder_layer, num_decoder_layers, decoder_norm,\n                                               return_intermediate=return_intermediate_dec,\n                                               look_twice=look_twice, detach_support_feat=detach_support_feat)\n\n        self.proposal_generator = ProposalGenerator(\n            hidden_dim=d_model,\n            proj_dim=similarity_proj_dim,\n            dynamic_proj_dim=dynamic_proj_dim)\n\n    def init_weights(self):\n        # follow the official DETR to init parameters\n        for m in self.modules():\n            if hasattr(m, 'weight') and m.weight.dim() > 1:\n                xavier_init(m, distribution='uniform')\n\n    def forward(self, src, mask, support_embed, pos_embed, support_order_embed,\n                query_padding_mask, position_embedding, kpt_branch, skeleton, return_attn_map=False):\n\n        bs, c, h, w = src.shape\n\n        src = src.flatten(2).permute(2, 0, 1)\n
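        # Token layout note: the encoder consumes [image tokens; support keypoint tokens] concatenated\n        # along the sequence dim, and positional embeddings are concatenated in the same order, so the\n        # decoder can slice them apart again by the image-token count (see GraphTransformerDecoderLayer).\n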
        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)\n        support_order_embed = support_order_embed.flatten(2).permute(2, 0, 1)\n        pos_embed = torch.cat((pos_embed, support_order_embed))\n        query_embed = support_embed.transpose(0, 1)\n        mask = mask.flatten(1)\n\n        query_embed, refined_support_embed = self.encoder(\n            src,\n            query_embed,\n            src_key_padding_mask=mask,\n            query_key_padding_mask=query_padding_mask,\n            pos=pos_embed)\n\n        # Generate initial proposals and corresponding positional embedding.\n        initial_proposals_for_loss, similarity_map, initial_proposals = self.proposal_generator(\n            query_embed, refined_support_embed, spatial_shape=[h, w])\n        initial_position_embedding = position_embedding.forward_coordinates(initial_proposals)\n\n        outs_dec, out_points, attn_maps = self.decoder(\n            refined_support_embed,\n            query_embed,\n            memory_key_padding_mask=mask,\n            pos=pos_embed,\n            query_pos=initial_position_embedding,\n            tgt_key_padding_mask=query_padding_mask,\n            position_embedding=position_embedding,\n            initial_proposals=initial_proposals,\n            kpt_branch=kpt_branch,\n            skeleton=skeleton,\n            return_attn_map=return_attn_map)\n\n        return outs_dec.transpose(1, 2), initial_proposals_for_loss, out_points, similarity_map\n\n\nclass GraphTransformerDecoder(nn.Module):\n\n    def __init__(self,\n                 d_model,\n                 decoder_layer,\n                 num_layers,\n                 norm=None,\n                 return_intermediate=False,\n                 look_twice=False,\n                 detach_support_feat=False):\n        super().__init__()\n        self.layers = _get_clones(decoder_layer, num_layers)\n        self.num_layers = num_layers\n        self.norm = norm\n        self.return_intermediate = return_intermediate\n        self.ref_point_head = MLP(d_model, d_model, d_model, num_layers=2)\n        self.look_twice = look_twice\n        self.detach_support_feat = detach_support_feat\n\n    def forward(self,\n                support_feat,\n                query_feat,\n                tgt_mask=None,\n                memory_mask=None,\n                tgt_key_padding_mask=None,\n                memory_key_padding_mask=None,\n                pos=None,\n                query_pos=None,\n                position_embedding=None,\n                initial_proposals=None,\n                kpt_branch=None,\n                skeleton=None,\n                return_attn_map=False):\n        \"\"\"\n        position_embedding: Class used to compute positional embedding\n        initial_proposals: [bs, nq, 2], normalized coordinates of initial proposals\n        kpt_branch: MLP used to predict the offsets for each query.\n        \"\"\"\n\n        refined_support_feat = support_feat\n        intermediate = []\n        attn_maps = []\n        bi = initial_proposals.detach()\n        bi_tag = initial_proposals.detach()\n        query_points = [initial_proposals.detach()]\n\n        tgt_key_padding_mask_remove_all_true = tgt_key_padding_mask.clone().to(tgt_key_padding_mask.device)\n        tgt_key_padding_mask_remove_all_true[tgt_key_padding_mask.logical_not().sum(dim=-1) == 0, 0] = False\n\n        for layer_idx, layer in enumerate(self.layers):\n            if layer_idx == 0:  # use positional embedding from initial proposals\n                query_pos_embed = 
query_pos.transpose(0, 1)\n            else:\n                # recalculate the positional embedding\n                query_pos_embed = position_embedding.forward_coordinates(bi)\n                query_pos_embed = query_pos_embed.transpose(0, 1)\n            query_pos_embed = self.ref_point_head(query_pos_embed)\n\n            if self.detach_support_feat:\n                refined_support_feat = refined_support_feat.detach()\n\n            refined_support_feat, attn_map = layer(\n                refined_support_feat,\n                query_feat,\n                tgt_mask=tgt_mask,\n                memory_mask=memory_mask,\n                tgt_key_padding_mask=tgt_key_padding_mask_remove_all_true,\n                memory_key_padding_mask=memory_key_padding_mask,\n                pos=pos,\n                query_pos=query_pos_embed,\n                skeleton=skeleton)\n\n            if self.return_intermediate:\n                intermediate.append(self.norm(refined_support_feat))\n\n            if return_attn_map:\n                attn_maps.append(attn_map)\n\n            # update the query coordinates\n            delta_bi = kpt_branch[layer_idx](refined_support_feat.transpose(0, 1))\n\n            # Prediction loss\n            if self.look_twice:\n                bi_pred = self.update(bi_tag, delta_bi)\n                bi_tag = self.update(bi, delta_bi)\n            else:\n                bi_tag = self.update(bi, delta_bi)\n                bi_pred = bi_tag\n\n            bi = bi_tag.detach()\n            query_points.append(bi_pred)\n\n        if self.norm is not None:\n            refined_support_feat = self.norm(refined_support_feat)\n            if self.return_intermediate:\n                intermediate.pop()\n                intermediate.append(refined_support_feat)\n\n        if self.return_intermediate:\n            return torch.stack(intermediate), query_points, attn_maps\n\n        return refined_support_feat.unsqueeze(0), query_points, attn_maps\n\n    def update(self, query_coordinates, delta_unsig):\n        query_coordinates_unsigmoid = inverse_sigmoid(query_coordinates)\n        new_query_coordinates = query_coordinates_unsigmoid + delta_unsig\n        new_query_coordinates = new_query_coordinates.sigmoid()\n        return new_query_coordinates\n\n\nclass GraphTransformerDecoderLayer(nn.Module):\n\n    def __init__(self,\n                 d_model,\n                 nhead,\n                 dim_feedforward=2048,\n                 dropout=0.1,\n                 activation=\"relu\",\n                 normalize_before=False,\n                 graph_decoder=None):\n\n        super().__init__()\n        self.graph_decoder = graph_decoder\n        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)\n        self.multihead_attn = nn.MultiheadAttention(\n            d_model * 2, nhead, dropout=dropout, vdim=d_model)\n        self.choker = nn.Linear(in_features=2 * d_model, out_features=d_model)\n        # Implementation of Feedforward model\n        if self.graph_decoder is None:\n            self.ffn1 = nn.Linear(d_model, dim_feedforward)\n            self.ffn2 = nn.Linear(dim_feedforward, d_model)\n        elif self.graph_decoder == 'pre':\n            self.ffn1 = GCNLayer(d_model, dim_feedforward, batch_first=False)\n            self.ffn2 = nn.Linear(dim_feedforward, d_model)\n        elif self.graph_decoder == 'post':\n            self.ffn1 = nn.Linear(d_model, dim_feedforward)\n            self.ffn2 = GCNLayer(dim_feedforward, d_model, batch_first=False)\n       
 else:\n            self.ffn1 = GCNLayer(d_model, dim_feedforward, batch_first=False)\n            self.ffn2 = GCNLayer(dim_feedforward, d_model, batch_first=False)\n\n        self.dropout = nn.Dropout(dropout)\n        self.norm1 = nn.LayerNorm(d_model)\n        self.norm2 = nn.LayerNorm(d_model)\n        self.norm3 = nn.LayerNorm(d_model)\n        self.dropout1 = nn.Dropout(dropout)\n        self.dropout2 = nn.Dropout(dropout)\n        self.dropout3 = nn.Dropout(dropout)\n\n        self.activation = _get_activation_fn(activation)\n        self.normalize_before = normalize_before\n\n    def with_pos_embed(self, tensor, pos: Optional[Tensor]):\n        return tensor if pos is None else tensor + pos\n\n    def forward(self,\n                refined_support_feat,\n                refined_query_feat,\n                tgt_mask: Optional[Tensor] = None,\n                memory_mask: Optional[Tensor] = None,\n                tgt_key_padding_mask: Optional[Tensor] = None,\n                memory_key_padding_mask: Optional[Tensor] = None,\n                pos: Optional[Tensor] = None,\n                query_pos: Optional[Tensor] = None,\n                skeleton: Optional[list] = None):\n\n        q = k = self.with_pos_embed(refined_support_feat, query_pos + pos[refined_query_feat.shape[0]:])\n        tgt2 = self.self_attn(\n            q,\n            k,\n            value=refined_support_feat,\n            attn_mask=tgt_mask,\n            key_padding_mask=tgt_key_padding_mask)[0]\n\n        refined_support_feat = refined_support_feat + self.dropout1(tgt2)\n        refined_support_feat = self.norm1(refined_support_feat)\n\n        # concatenate the positional embedding with the content feature, instead of direct addition\n        cross_attn_q = torch.cat((refined_support_feat, query_pos + pos[refined_query_feat.shape[0]:]), dim=-1)\n        cross_attn_k = torch.cat((refined_query_feat, pos[:refined_query_feat.shape[0]]), dim=-1)\n\n        tgt2, attn_map = self.multihead_attn(\n            query=cross_attn_q,\n            key=cross_attn_k,\n            value=refined_query_feat,\n            attn_mask=memory_mask,\n            key_padding_mask=memory_key_padding_mask)\n\n        refined_support_feat = refined_support_feat + self.dropout2(self.choker(tgt2))\n        refined_support_feat = self.norm2(refined_support_feat)\n        if self.graph_decoder is not None:\n            num_pts, b, c = refined_support_feat.shape\n            adj = adj_from_skeleton(num_pts=num_pts,\n                                    skeleton=skeleton,\n                                    mask=tgt_key_padding_mask,\n                                    device=refined_support_feat.device)\n            if self.graph_decoder == 'pre':\n                tgt2 = self.ffn2(self.dropout(self.activation(self.ffn1(refined_support_feat, adj))))\n            elif self.graph_decoder == 'post':\n                tgt2 = self.ffn2(self.dropout(self.activation(self.ffn1(refined_support_feat))), adj)\n            else:\n                tgt2 = self.ffn2(self.dropout(self.activation(self.ffn1(refined_support_feat, adj))), adj)\n        else:\n            tgt2 = self.ffn2(self.dropout(self.activation(self.ffn1(refined_support_feat))))\n        refined_support_feat = refined_support_feat + self.dropout3(tgt2)\n        refined_support_feat = self.norm3(refined_support_feat)\n\n        return refined_support_feat, attn_map\n\n\nclass TransformerEncoder(nn.Module):\n\n    def __init__(self, encoder_layer, num_layers, norm=None):\n        
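# NOTE: this encoder self-attends over the concatenation of image tokens and support keypoint\n        # tokens (see forward below), refining both jointly before proposals are generated.\n        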
super().__init__()\n        self.layers = _get_clones(encoder_layer, num_layers)\n        self.num_layers = num_layers\n        self.norm = norm\n\n    def forward(self,\n                src,\n                query,\n                mask: Optional[Tensor] = None,\n                src_key_padding_mask: Optional[Tensor] = None,\n                query_key_padding_mask: Optional[Tensor] = None,\n                pos: Optional[Tensor] = None):\n        # src: [hw, bs, c]\n        # query: [num_query, bs, c]\n        # mask: None by default\n        # src_key_padding_mask: [bs, hw]\n        # query_key_padding_mask: [bs, nq]\n        # pos: [hw, bs, c]\n\n        # organize the input\n        # implement the attention mask to mask out the useless points\n        n, bs, c = src.shape\n        src_cat = torch.cat((src, query), dim=0)  # [hw + nq, bs, c]\n        mask_cat = torch.cat((src_key_padding_mask, query_key_padding_mask),\n                             dim=1)  # [bs, hw+nq]\n        output = src_cat\n\n        for layer in self.layers:\n            output = layer(\n                output,\n                query_length=n,\n                src_mask=mask,\n                src_key_padding_mask=mask_cat,\n                pos=pos)\n\n        if self.norm is not None:\n            output = self.norm(output)\n\n        # resplit the output into src and query\n        refined_query = output[n:, :, :]  # [nq, bs, c]\n        output = output[:n, :, :]  # [n, bs, c]\n\n        return output, refined_query\n\n\nclass TransformerEncoderLayer(nn.Module):\n\n    def __init__(self,\n                 d_model,\n                 nhead,\n                 dim_feedforward=2048,\n                 dropout=0.1,\n                 activation=\"relu\",\n                 normalize_before=False):\n        super().__init__()\n        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)\n        # Implementation of Feedforward model\n        self.linear1 = nn.Linear(d_model, dim_feedforward)\n        self.dropout = nn.Dropout(dropout)\n        self.linear2 = nn.Linear(dim_feedforward, d_model)\n\n        self.norm1 = nn.LayerNorm(d_model)\n        self.norm2 = nn.LayerNorm(d_model)\n        self.dropout1 = nn.Dropout(dropout)\n        self.dropout2 = nn.Dropout(dropout)\n\n        self.activation = _get_activation_fn(activation)\n        self.normalize_before = normalize_before\n\n    def with_pos_embed(self, tensor, pos: Optional[Tensor]):\n        return tensor if pos is None else tensor + pos\n\n    def forward(self,\n                src,\n                query_length,\n                src_mask: Optional[Tensor] = None,\n                src_key_padding_mask: Optional[Tensor] = None,\n                pos: Optional[Tensor] = None):\n        src = self.with_pos_embed(src, pos)\n        q = k = src\n        # NOTE: compared with original implementation, we add positional embedding into the VALUE.\n        src2 = self.self_attn(\n            q,\n            k,\n            value=src,\n            attn_mask=src_mask,\n            key_padding_mask=src_key_padding_mask)[0]\n        src = src + self.dropout1(src2)\n        src = self.norm1(src)\n        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))\n        src = src + self.dropout2(src2)\n        src = self.norm2(src)\n        return src\n\n\ndef adj_from_skeleton(num_pts, skeleton, mask, device='cuda'):\n    adj_mx = torch.empty(0, device=device)\n    batch_size = len(skeleton)\n    for b in range(batch_size):\n        edges = 
torch.tensor(skeleton[b])\n        adj = torch.zeros(num_pts, num_pts, device=device)\n        adj[edges[:, 0], edges[:, 1]] = 1\n        adj_mx = torch.cat((adj_mx, adj.unsqueeze(0)), dim=0)\n    trans_adj_mx = torch.transpose(adj_mx, 1, 2)\n    cond = (trans_adj_mx > adj_mx).float()\n    adj = adj_mx + trans_adj_mx * cond - adj_mx * cond\n    adj = adj * ~mask[..., None] * ~mask[:, None]\n    adj = torch.nan_to_num(adj / adj.sum(dim=-1, keepdim=True))\n    adj = torch.stack((torch.diag_embed(~mask), adj), dim=1)\n    return adj\n\n\nclass GCNLayer(nn.Module):\n    def __init__(self,\n                 in_features,\n                 out_features,\n                 kernel_size=2,\n                 use_bias=True,\n                 activation=nn.ReLU(inplace=True),\n                 batch_first=True):\n        super(GCNLayer, self).__init__()\n        self.conv = nn.Conv1d(in_features, out_features * kernel_size, kernel_size=1,\n                              padding=0, stride=1, dilation=1, bias=use_bias)\n        self.kernel_size = kernel_size\n        self.activation = activation\n        self.batch_first = batch_first\n\n    def forward(self, x, adj):\n        assert adj.size(1) == self.kernel_size\n        if not self.batch_first:\n            x = x.permute(1, 2, 0)\n        else:\n            x = x.transpose(1, 2)\n        x = self.conv(x)\n        b, kc, v = x.size()\n        x = x.view(b, self.kernel_size, kc // self.kernel_size, v)\n        x = torch.einsum('bkcv,bkvw->bcw', (x, adj))\n        if self.activation is not None:\n            x = self.activation(x)\n        if not self.batch_first:\n            x = x.permute(2, 0, 1)\n        else:\n            x = x.transpose(1, 2)\n        return x\n\n\ndef _get_clones(module, N):\n    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])\n\n\ndef _get_activation_fn(activation):\n    \"\"\"Return an activation function given a string\"\"\"\n    if activation == \"relu\":\n        return F.relu\n    if activation == \"gelu\":\n        return F.gelu\n    if activation == \"glu\":\n        return F.glu\n    raise RuntimeError(f\"activation should be relu/gelu/glu, not {activation}.\")\n\n\ndef clones(module, N):\n    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])\n"
  },
  {
    "path": "models/models/utils/positional_encoding.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport math\n\nimport torch\nimport torch.nn as nn\nfrom mmcv.cnn.bricks.transformer import POSITIONAL_ENCODING\nfrom mmcv.runner import BaseModule\n\n\n# TODO: add a SinePositionalEncoding for coordinates input\n\n@POSITIONAL_ENCODING.register_module()\nclass SinePositionalEncoding(BaseModule):\n    \"\"\"Position encoding with sine and cosine functions.\n\n    See `End-to-End Object Detection with Transformers\n    <https://arxiv.org/pdf/2005.12872>`_ for details.\n\n    Args:\n        num_feats (int): The feature dimension for each position\n            along x-axis or y-axis. Note the final returned dimension\n            for each position is 2 times this value.\n        temperature (int, optional): The temperature used for scaling\n            the position embedding. Defaults to 10000.\n        normalize (bool, optional): Whether to normalize the position\n            embedding. Defaults to False.\n        scale (float, optional): A scale factor that scales the position\n            embedding. The scale will be used only when `normalize` is True.\n            Defaults to 2*pi.\n        eps (float, optional): A value added to the denominator for\n            numerical stability. Defaults to 1e-6.\n        offset (float): offset added to embed when doing the normalization.\n            Defaults to 0.\n        init_cfg (dict or list[dict], optional): Initialization config dict.\n            Default: None\n    \"\"\"\n\n    def __init__(self,\n                 num_feats,\n                 temperature=10000,\n                 normalize=False,\n                 scale=2 * math.pi,\n                 eps=1e-6,\n                 offset=0.,\n                 init_cfg=None):\n        super(SinePositionalEncoding, self).__init__(init_cfg)\n        if normalize:\n            assert isinstance(scale, (float, int)), 'when normalize is set,' \\\n                                                    'scale should be provided and in float or int type, ' \\\n                                                    f'found {type(scale)}'\n        self.num_feats = num_feats\n        self.temperature = temperature\n        self.normalize = normalize\n        self.scale = scale\n        self.eps = eps\n        self.offset = offset\n\n    def forward(self, mask):\n        \"\"\"Forward function for `SinePositionalEncoding`.\n\n        Args:\n            mask (Tensor): ByteTensor mask. Non-zero values represent\n                ignored positions, while zero values mean valid positions\n                for this image.
 Shape [bs, h, w].\n\n        Returns:\n            pos (Tensor): Returned position embedding with shape\n                [bs, num_feats*2, h, w].\n        \"\"\"\n        # For convenience of exporting to ONNX, it's required to convert\n        # `masks` from bool to int.\n        mask = mask.to(torch.int)\n        not_mask = 1 - mask  # logical_not\n        y_embed = not_mask.cumsum(1, dtype=torch.float32)  # [bs, h, w], recording the y coordinate of each pixel\n        x_embed = not_mask.cumsum(2, dtype=torch.float32)\n        if self.normalize:  # default True\n            y_embed = (y_embed + self.offset) / \\\n                      (y_embed[:, -1:, :] + self.eps) * self.scale\n            x_embed = (x_embed + self.offset) / \\\n                      (x_embed[:, :, -1:] + self.eps) * self.scale\n
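        # Sine encoding sketch: pe[2i] = sin(coord / T^(2i/d)) and pe[2i+1] = cos(coord / T^(2i/d)),\n        # with d = num_feats and T = temperature; y- and x-embeddings are concatenated channel-wise.\n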
        dim_t = torch.arange(\n            self.num_feats, dtype=torch.float32, device=mask.device)\n        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_feats)\n        pos_x = x_embed[:, :, :, None] / dim_t  # [bs, h, w, num_feats]\n        pos_y = y_embed[:, :, :, None] / dim_t\n        # use `view` instead of `flatten` for dynamically exporting to ONNX\n        B, H, W = mask.size()\n        pos_x = torch.stack(\n            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()),\n            dim=4).view(B, H, W, -1)  # [bs, h, w, num_feats]\n        pos_y = torch.stack(\n            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()),\n            dim=4).view(B, H, W, -1)\n        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)\n        return pos\n\n    def forward_coordinates(self, coord):\n        \"\"\"\n        Forward function for normalized coordinates input with the shape of [bs, kpt, 2]\n        return:\n            pos (Tensor): position embedding with the shape of [bs, kpt, num_feats*2]\n        \"\"\"\n        x_embed, y_embed = coord[:, :, 0], coord[:, :, 1]  # [bs, kpt]\n        x_embed = x_embed * self.scale  # [bs, kpt]\n        y_embed = y_embed * self.scale\n\n        dim_t = torch.arange(\n            self.num_feats, dtype=torch.float32, device=coord.device)\n        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_feats)\n\n        pos_x = x_embed[:, :, None] / dim_t  # [bs, kpt, num_feats]\n        pos_y = y_embed[:, :, None] / dim_t  # [bs, kpt, num_feats]\n        bs, kpt, _ = pos_x.shape\n\n        pos_x = torch.stack(\n            (pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()),\n            dim=3).view(bs, kpt, -1)  # [bs, kpt, num_feats]\n        pos_y = torch.stack(\n            (pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()),\n            dim=3).view(bs, kpt, -1)  # [bs, kpt, num_feats]\n        pos = torch.cat((pos_y, pos_x), dim=2)  # [bs, kpt, num_feats * 2]\n\n        return pos\n\n    def __repr__(self):\n        \"\"\"str: a string that describes the module\"\"\"\n        repr_str = self.__class__.__name__\n        repr_str += f'(num_feats={self.num_feats}, '\n        repr_str += f'temperature={self.temperature}, '\n        repr_str += f'normalize={self.normalize}, '\n        repr_str += f'scale={self.scale}, '\n        repr_str += f'eps={self.eps})'\n        return repr_str\n\n\n@POSITIONAL_ENCODING.register_module()\nclass LearnedPositionalEncoding(BaseModule):\n    \"\"\"Position embedding with learnable embedding weights.\n\n    Args:\n        num_feats (int): The feature dimension for each position\n            along x-axis or y-axis. The final returned dimension for\n            each position is 2 times this value.\n        row_num_embed (int, optional): The dictionary size of row embeddings.\n            Default 50.\n        col_num_embed (int, optional): The dictionary size of col embeddings.\n            Default 50.\n        init_cfg (dict or list[dict], optional): Initialization config dict.\n    \"\"\"\n\n    def __init__(self,\n                 num_feats,\n                 row_num_embed=50,\n                 col_num_embed=50,\n                 init_cfg=dict(type='Uniform', layer='Embedding')):\n        super(LearnedPositionalEncoding, self).__init__(init_cfg)\n        self.row_embed = nn.Embedding(row_num_embed, num_feats)\n        self.col_embed = nn.Embedding(col_num_embed, num_feats)\n        self.num_feats = num_feats\n        self.row_num_embed = row_num_embed\n        self.col_num_embed = col_num_embed\n\n    def forward(self, mask):\n        \"\"\"Forward function for `LearnedPositionalEncoding`.\n\n        Args:\n            mask (Tensor): ByteTensor mask. Non-zero values represent\n                ignored positions, while zero values mean valid positions\n                for this image. Shape [bs, h, w].\n\n        Returns:\n            pos (Tensor): Returned position embedding with shape\n                [bs, num_feats*2, h, w].\n        \"\"\"\n        h, w = mask.shape[-2:]\n        x = torch.arange(w, device=mask.device)\n        y = torch.arange(h, device=mask.device)\n        x_embed = self.col_embed(x)\n        y_embed = self.row_embed(y)\n        pos = torch.cat(\n            (x_embed.unsqueeze(0).repeat(h, 1, 1), y_embed.unsqueeze(1).repeat(\n                1, w, 1)),\n            dim=-1).permute(2, 0,\n                            1).unsqueeze(0).repeat(mask.shape[0], 1, 1, 1)\n        return pos\n\n    def __repr__(self):\n        \"\"\"str: a string that describes the module\"\"\"\n        repr_str = self.__class__.__name__\n        repr_str += f'(num_feats={self.num_feats}, '\n        repr_str += f'row_num_embed={self.row_num_embed}, '\n        repr_str += f'col_num_embed={self.col_num_embed})'\n        return repr_str\n"
  },
  {
    "path": "models/models/utils/transformer.py",
    "content": "import torch\nimport torch.nn as nn\nfrom models.models.utils.builder import TRANSFORMER\nfrom mmcv.cnn import (build_activation_layer, build_norm_layer, xavier_init)\nfrom mmcv.cnn.bricks.registry import (TRANSFORMER_LAYER,\n                                      TRANSFORMER_LAYER_SEQUENCE)\nfrom mmcv.cnn.bricks.transformer import (BaseTransformerLayer,\n                                         TransformerLayerSequence,\n                                         build_transformer_layer_sequence)\nfrom mmcv.runner.base_module import BaseModule\n\n\n@TRANSFORMER.register_module()\nclass Transformer(BaseModule):\n    \"\"\"Implements the DETR transformer.\n    Following the official DETR implementation, this module is copy-pasted\n    from torch.nn.Transformer with modifications:\n        * positional encodings are passed in MultiheadAttention\n        * extra LN at the end of encoder is removed\n        * decoder returns a stack of activations from all decoding layers\n    See `paper: End-to-End Object Detection with Transformers\n    <https://arxiv.org/pdf/2005.12872>`_ for details.\n    Args:\n        encoder (`mmcv.ConfigDict` | Dict): Config of\n            TransformerEncoder. Defaults to None.\n        decoder (`mmcv.ConfigDict` | Dict): Config of\n            TransformerDecoder. Defaults to None\n        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.\n            Defaults to None.\n    \"\"\"\n\n    def __init__(self, encoder=None, decoder=None, init_cfg=None):\n        super(Transformer, self).__init__(init_cfg=init_cfg)\n        self.encoder = build_transformer_layer_sequence(encoder)\n        self.decoder = build_transformer_layer_sequence(decoder)\n        self.embed_dims = self.encoder.embed_dims\n\n    def init_weights(self):\n        # follow the official DETR to init parameters\n        for m in self.modules():\n            if hasattr(m, 'weight') and m.weight.dim() > 1:\n                xavier_init(m, distribution='uniform')\n        self._is_init = True\n\n    def forward(self, x, mask, query_embed, pos_embed, mask_query):\n        \"\"\"Forward function for `Transformer`.\n        Args:\n            x (Tensor): Input query with shape [bs, c, h, w] where\n                c = embed_dims.\n            mask (Tensor): The key_padding_mask used for encoder and decoder,\n                with shape [bs, h, w].\n            query_embed (Tensor): The query embedding for decoder, with shape\n                [num_query, c].\n            pos_embed (Tensor): The positional encoding for encoder and\n                decoder, with the same shape as `x`.\n        Returns:\n            tuple[Tensor]: results of decoder containing the following tensor.\n                - out_dec: Output from decoder.
 If return_intermediate_dec \\\n                      is True output has shape [num_dec_layers, bs,\n                      num_query, embed_dims], else has shape [1, bs, \\\n                      num_query, embed_dims].\n                - memory: Output results from encoder, with shape \\\n                      [bs, embed_dims, h, w].\n\n        Notes:\n            x: query image features with shape [bs, c, h, w]\n            mask: mask for x with shape [bs, h, w]\n            pos_embed: positional embedding for x with shape [bs, c, h, w]\n            query_embed: sample keypoint features with shape [bs, num_query, c]\n            mask_query: mask for query_embed with shape [bs, num_query]\n        Outputs:\n            out_dec: [num_layers, bs, num_query, c]\n            memory: [bs, c, h, w]\n\n        \"\"\"\n        bs, c, h, w = x.shape\n        # use `view` instead of `flatten` for dynamically exporting to ONNX\n        x = x.view(bs, c, -1).permute(2, 0, 1)  # [bs, c, h, w] -> [h*w, bs, c]\n        mask = mask.view(bs,\n                         -1)  # [bs, h, w] -> [bs, h*w] Note: this mask should be filled with False, since all images are of the same shape.\n        pos_embed = pos_embed.view(bs, c, -1).permute(2, 0, 1)  # positional embedding for memory, i.e., the query.\n        memory = self.encoder(\n            query=x,\n            key=None,\n            value=None,\n            query_pos=pos_embed,\n            query_key_padding_mask=mask)  # output memory: [hw, bs, c]\n\n        query_embed = query_embed.permute(1, 0, 2)  # [bs, num_query, c] -> [num_query, bs, c]\n        # target = torch.zeros_like(query_embed)\n        # out_dec: [num_layers, num_query, bs, c]\n        out_dec = self.decoder(\n            query=query_embed,\n            key=memory,\n            value=memory,\n            key_pos=pos_embed,\n            # query_pos=query_embed,\n            query_key_padding_mask=mask_query,\n            key_padding_mask=mask)\n        out_dec = out_dec.transpose(1, 2)  # [decoder_layer, bs, num_query, c]\n        memory = memory.permute(1, 2, 0).reshape(bs, c, h, w)\n        return out_dec, memory\n
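\n# NOTE (assumption): this vanilla DETR Transformer does not appear to be used by PoseHead, whose\n# transformer call signature matches EncoderDecoder in encoder_decoder.py; presumably kept for reference.\n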
\n\n@TRANSFORMER_LAYER.register_module()\nclass DetrTransformerDecoderLayer(BaseTransformerLayer):\n    \"\"\"Implements decoder layer in DETR transformer.\n    Args:\n        attn_cfgs (list[`mmcv.ConfigDict`] | list[dict] | dict):\n            Configs for self_attention or cross_attention, the order\n            should be consistent with it in `operation_order`. If it is\n            a dict, it would be expanded to the number of attention modules in\n            `operation_order`.\n        feedforward_channels (int): The hidden dimension for FFNs.\n        ffn_dropout (float): Probability of an element to be zeroed\n            in ffn. Default 0.0.\n        operation_order (tuple[str]): The execution order of operation\n            in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm').\n            Default: None\n        act_cfg (dict): The activation config for FFNs.\n            Default: dict(type='ReLU', inplace=True).\n        norm_cfg (dict): Config dict for normalization layer.\n            Default: `LN`.\n        ffn_num_fcs (int): The number of fully-connected layers in FFNs.\n            Default: 2.\n    \"\"\"\n\n    def __init__(self,\n                 attn_cfgs,\n                 feedforward_channels,\n                 ffn_dropout=0.0,\n                 operation_order=None,\n                 act_cfg=dict(type='ReLU', inplace=True),\n                 norm_cfg=dict(type='LN'),\n                 ffn_num_fcs=2,\n                 **kwargs):\n        super(DetrTransformerDecoderLayer, self).__init__(\n            attn_cfgs=attn_cfgs,\n            feedforward_channels=feedforward_channels,\n            ffn_dropout=ffn_dropout,\n            operation_order=operation_order,\n            act_cfg=act_cfg,\n            norm_cfg=norm_cfg,\n            ffn_num_fcs=ffn_num_fcs,\n            **kwargs)\n        # assert len(operation_order) == 6\n        # assert set(operation_order) == set(\n        #     ['self_attn', 'norm', 'cross_attn', 'ffn'])\n\n\n@TRANSFORMER_LAYER_SEQUENCE.register_module()\nclass DetrTransformerEncoder(TransformerLayerSequence):\n    \"\"\"TransformerEncoder of DETR.\n    Args:\n        post_norm_cfg (dict): Config of last normalization layer. Default:\n            `LN`. Only used when `self.pre_norm` is `True`\n    \"\"\"\n\n    def __init__(self, *args, post_norm_cfg=dict(type='LN'), **kwargs):\n        super(DetrTransformerEncoder, self).__init__(*args, **kwargs)\n        if post_norm_cfg is not None:\n            self.post_norm = build_norm_layer(\n                post_norm_cfg, self.embed_dims)[1] if self.pre_norm else None\n        else:\n            # assert not self.pre_norm, f'Use prenorm in ' \\\n            #                           f'{self.__class__.__name__},' \\\n            #                           f'Please specify post_norm_cfg'\n            self.post_norm = None\n\n    def forward(self, *args, **kwargs):\n        \"\"\"Forward function for `TransformerEncoder`.\n        Returns:\n            Tensor: forwarded results with shape [num_query, bs, embed_dims].\n        \"\"\"\n        x = super(DetrTransformerEncoder, self).forward(*args, **kwargs)\n        if self.post_norm is not None:\n            x = self.post_norm(x)\n        return x\n\n\n@TRANSFORMER_LAYER_SEQUENCE.register_module()\nclass DetrTransformerDecoder(TransformerLayerSequence):\n    \"\"\"Implements the decoder in DETR transformer.\n    Args:\n        return_intermediate (bool): Whether to return intermediate outputs.\n        post_norm_cfg (dict): Config of last normalization layer.
 Default:\n            `LN`.\n    \"\"\"\n\n    def __init__(self,\n                 *args,\n                 post_norm_cfg=dict(type='LN'),\n                 return_intermediate=False,\n                 **kwargs):\n\n        super(DetrTransformerDecoder, self).__init__(*args, **kwargs)\n        self.return_intermediate = return_intermediate\n        if post_norm_cfg is not None:\n            self.post_norm = build_norm_layer(post_norm_cfg,\n                                              self.embed_dims)[1]\n        else:\n            self.post_norm = None\n\n    def forward(self, query, *args, **kwargs):\n        \"\"\"Forward function for `TransformerDecoder`.\n        Args:\n            query (Tensor): Input query with shape\n                `(num_query, bs, embed_dims)`.\n        Returns:\n            Tensor: Results with shape [1, num_query, bs, embed_dims] when\n                return_intermediate is `False`, otherwise it has shape\n                [num_layers, num_query, bs, embed_dims].\n        \"\"\"\n        if not self.return_intermediate:\n            x = super().forward(query, *args, **kwargs)\n            if self.post_norm:\n                x = self.post_norm(x)[None]\n            return x\n\n        intermediate = []\n        for layer in self.layers:\n            query = layer(query, *args, **kwargs)\n            if self.return_intermediate:\n                if self.post_norm is not None:\n                    intermediate.append(self.post_norm(query))\n                else:\n                    intermediate.append(query)\n        return torch.stack(intermediate)\n\n\n@TRANSFORMER.register_module()\nclass DynamicConv(BaseModule):\n    \"\"\"Implements Dynamic Convolution.\n    This module generates parameters for each sample and\n    uses bmm to implement 1*1 convolution. Code is modified\n    from the `official github repo <https://github.com/PeizeSun/\n    SparseR-CNN/blob/main/projects/SparseRCNN/sparsercnn/head.py#L258>`_ .\n    Args:\n        in_channels (int): The input feature channel.\n            Defaults to 256.\n        feat_channels (int): The inner feature channel.\n            Defaults to 64.\n        out_channels (int, optional): The output feature channel.\n            When not specified, it will be set to `in_channels`\n            by default\n        input_feat_shape (int): The shape of input feature.\n            Defaults to 7.\n        with_proj (bool): Project two-dimensional feature to\n            one-dimensional feature. Defaults to True.\n        act_cfg (dict): The activation config for DynamicConv.\n        norm_cfg (dict): Config dict for normalization layer.
 Defaults to\n            layer normalization.\n        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.\n            Default: None.\n    \"\"\"\n\n    def __init__(self,\n                 in_channels=256,\n                 feat_channels=64,\n                 out_channels=None,\n                 input_feat_shape=7,\n                 with_proj=True,\n                 act_cfg=dict(type='ReLU', inplace=True),\n                 norm_cfg=dict(type='LN'),\n                 init_cfg=None):\n        super(DynamicConv, self).__init__(init_cfg)\n        self.in_channels = in_channels\n        self.feat_channels = feat_channels\n        self.out_channels_raw = out_channels\n        self.input_feat_shape = input_feat_shape\n        self.with_proj = with_proj\n        self.act_cfg = act_cfg\n        self.norm_cfg = norm_cfg\n        self.out_channels = out_channels if out_channels else in_channels\n\n        self.num_params_in = self.in_channels * self.feat_channels\n        self.num_params_out = self.out_channels * self.feat_channels\n        self.dynamic_layer = nn.Linear(\n            self.in_channels, self.num_params_in + self.num_params_out)\n\n        self.norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1]\n        self.norm_out = build_norm_layer(norm_cfg, self.out_channels)[1]\n\n        self.activation = build_activation_layer(act_cfg)\n\n        num_output = self.out_channels * input_feat_shape ** 2\n        if self.with_proj:\n            self.fc_layer = nn.Linear(num_output, self.out_channels)\n            self.fc_norm = build_norm_layer(norm_cfg, self.out_channels)[1]\n\n    def forward(self, param_feature, input_feature):\n        \"\"\"Forward function for `DynamicConv`.\n        Args:\n            param_feature (Tensor): Feature used to generate\n                the dynamic parameters, has shape\n                (num_all_proposals, in_channels).\n            input_feature (Tensor): Feature that\n                interacts with parameters, has shape\n                (num_all_proposals, in_channels, H, W).\n        Returns:\n            Tensor: The output feature has shape\n            (num_all_proposals, out_channels).\n        \"\"\"\n        # (num_all_proposals, in_channels, H, W) ->\n        # (num_all_proposals, H*W, in_channels)\n        input_feature = input_feature.flatten(2).permute(0, 2, 1)\n        parameters = self.dynamic_layer(param_feature)\n\n        param_in = parameters[:, :self.num_params_in].view(\n            -1, self.in_channels, self.feat_channels)\n        param_out = parameters[:, -self.num_params_out:].view(\n            -1, self.feat_channels, self.out_channels)\n\n        # input_feature has shape (num_all_proposals, H*W, in_channels)\n        # param_in has shape (num_all_proposals, in_channels, feat_channels)\n        # features has shape (num_all_proposals, H*W, feat_channels)\n        features = torch.bmm(input_feature, param_in)\n        features = self.norm_in(features)\n        features = self.activation(features)\n\n        # param_out has shape (num_all_proposals, feat_channels, out_channels)\n        features = torch.bmm(features, param_out)\n        features = self.norm_out(features)\n        features = self.activation(features)\n\n        if self.with_proj:\n            features = features.flatten(1)\n            features = self.fc_layer(features)\n            features = self.fc_norm(features)\n            features = self.activation(features)\n\n        return features\n"
  },
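The `DynamicConv` module above implements 1x1 convolution with per-sample weights via `torch.bmm`. The following is a minimal, standalone sketch of that trick (not the module itself); dimensions follow the docstring defaults, and all names here are illustrative:

```python
# Standalone sketch of DynamicConv's bmm-as-1x1-conv trick, using the
# docstring defaults (in_channels=256, feat_channels=64, 7x7 patches).
import torch
import torch.nn as nn

num_proposals, in_ch, feat_ch, hw = 4, 256, 64, 7 * 7

param_feature = torch.randn(num_proposals, in_ch)      # one vector per sample
input_feature = torch.randn(num_proposals, in_ch, hw)  # (N, C, H*W)

# A plain linear layer emits a different weight matrix for every sample.
dynamic_layer = nn.Linear(in_ch, in_ch * feat_ch)
param_in = dynamic_layer(param_feature).view(num_proposals, in_ch, feat_ch)

# (N, H*W, in_ch) @ (N, in_ch, feat_ch) -> (N, H*W, feat_ch): every spatial
# position is mixed across channels only, i.e. a per-sample 1x1 convolution.
features = torch.bmm(input_feature.permute(0, 2, 1), param_in)
print(features.shape)  # torch.Size([4, 49, 64])
```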
  {
    "path": "models/version.py",
    "content": "# GENERATED VERSION FILE\n# TIME: Tue Dec 19 17:01:21 2023\n__version__ = '0.2.0+f65cb07'\nshort_version = '0.2.0'\nversion_info = (0, 2, 0)\n"
  },
  {
    "path": "requirements.txt",
    "content": "json_tricks\nnumpy\nopencv-python\npillow==6.2.2\nxtcocotools\nscipy"
  },
  {
    "path": "setup.cfg",
    "content": "[bdist_wheel]\nuniversal=1\n\n[aliases]\ntest=pytest\n\n[tool:pytest]\naddopts=tests/\n\n[yapf]\nbased_on_style = pep8\nblank_line_before_nested_class_or_def = true\nsplit_before_expression_after_opening_paren = true\n\n[isort]\nline_length = 79\nmulti_line_output = 0\nknown_standard_library = pkg_resources,setuptools\nknown_first_party = mmpose\nknown_third_party = cv2,json_tricks,mmcv,mmdet,munkres,numpy,xtcocotools,torch\nno_lines_before = STDLIB,LOCALFOLDER\ndefault_section = THIRDPARTY\n"
  },
  {
    "path": "setup.py",
    "content": "import os\nimport subprocess\nimport time\nfrom setuptools import find_packages, setup\n\n\ndef readme():\n    with open('README.md', encoding='utf-8') as f:\n        content = f.read()\n    return content\n\n\nversion_file = 'models/version.py'\n\n\ndef get_git_hash():\n\n    def _minimal_ext_cmd(cmd):\n        # construct minimal environment\n        env = {}\n        for k in ['SYSTEMROOT', 'PATH', 'HOME']:\n            v = os.environ.get(k)\n            if v is not None:\n                env[k] = v\n        # LANGUAGE is used on win32\n        env['LANGUAGE'] = 'C'\n        env['LANG'] = 'C'\n        env['LC_ALL'] = 'C'\n        out = subprocess.Popen(\n            cmd, stdout=subprocess.PIPE, env=env).communicate()[0]\n        return out\n\n    try:\n        out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])\n        sha = out.strip().decode('ascii')\n    except OSError:\n        sha = 'unknown'\n\n    return sha\n\n\ndef get_hash():\n    if os.path.exists('.git'):\n        sha = get_git_hash()[:7]\n    elif os.path.exists(version_file):\n        try:\n            from models.version import __version__\n            sha = __version__.split('+')[-1]\n        except ImportError:\n            raise ImportError('Unable to get git version')\n    else:\n        sha = 'unknown'\n\n    return sha\n\n\ndef write_version_py():\n    content = \"\"\"# GENERATED VERSION FILE\n# TIME: {}\n__version__ = '{}'\nshort_version = '{}'\nversion_info = ({})\n\"\"\"\n    sha = get_hash()\n    with open('models/VERSION', 'r') as f:\n        SHORT_VERSION = f.read().strip()\n    VERSION_INFO = ', '.join(SHORT_VERSION.split('.'))\n    VERSION = SHORT_VERSION + '+' + sha\n\n    version_file_str = content.format(time.asctime(), VERSION, SHORT_VERSION,\n                                      VERSION_INFO)\n    with open(version_file, 'w') as f:\n        f.write(version_file_str)\n\n\ndef get_version():\n    with open(version_file, 'r') as f:\n        exec(compile(f.read(), version_file, 'exec'))\n    return locals()['__version__']\n\n\ndef get_requirements(filename='requirements.txt'):\n    here = os.path.dirname(os.path.realpath(__file__))\n    with open(os.path.join(here, filename), 'r') as f:\n        requires = [line.replace('\\n', '') for line in f.readlines()]\n    return requires\n\n\nif __name__ == '__main__':\n    write_version_py()\n    setup(\n        name='pose_anything',\n        version=get_version(),\n        description='A template for pytorch projects.',\n        long_description=readme(),\n        packages=find_packages(exclude=('configs', 'tools', 'demo')),\n        package_data={'pose_anything.ops': ['*/*.so']},\n        classifiers=[\n            'Development Status :: 4 - Beta',\n            'License :: OSI Approved :: Apache Software License',\n            'Operating System :: OS Independent',\n            'Programming Language :: Python :: 3',\n            'Programming Language :: Python :: 3.5',\n            'Programming Language :: Python :: 3.6',\n            'Programming Language :: Python :: 3.7',\n        ],\n        license='Apache License 2.0',\n        setup_requires=['pytest-runner', 'cython', 'numpy'],\n        tests_require=['pytest', 'xdoctest'],\n        install_requires=get_requirements(),\n        zip_safe=False)\n"
  },
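For reference, the version strings that `write_version_py` in `setup.py` assembles (and from which the generated `models/version.py` above was produced) are composed as follows:

```python
# How setup.py composes the version identifiers, shown with the values
# from the generated models/version.py above.
short_version = '0.2.0'  # read from models/VERSION
sha = 'f65cb07'          # first 7 chars of `git rev-parse HEAD`

version = short_version + '+' + sha
version_info = ', '.join(short_version.split('.'))

print(version)                   # 0.2.0+f65cb07
print('(' + version_info + ')')  # (0, 2, 0)
```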
  {
    "path": "test.py",
    "content": "import argparse\nimport os\nimport os.path as osp\nimport random\nimport uuid\n\nimport mmcv\nimport numpy as np\nimport torch\nfrom mmcv import Config, DictAction\nfrom mmcv.cnn import fuse_conv_bn\nfrom mmcv.parallel import MMDataParallel, MMDistributedDataParallel\nfrom mmcv.runner import get_dist_info, init_dist, load_checkpoint\nfrom models import *  # noqa\nfrom models.datasets import build_dataset\n\nfrom mmpose.apis import multi_gpu_test, single_gpu_test\nfrom mmpose.core import wrap_fp16_model\nfrom mmpose.datasets import build_dataloader\nfrom mmpose.models import build_posenet\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description='mmpose test model')\n    parser.add_argument('config', default=None, help='test config file path')\n    parser.add_argument('checkpoint', default=None, help='checkpoint file')\n    parser.add_argument('--out', help='output result file')\n    parser.add_argument(\n        '--fuse-conv-bn',\n        action='store_true',\n        help='Whether to fuse conv and bn, this will slightly increase the inference speed')\n    parser.add_argument(\n        '--eval',\n        default=None,\n        nargs='+',\n        help='evaluation metric, which depends on the dataset,'\n        ' e.g., \"mAP\" for MSCOCO')\n    parser.add_argument(\n        '--permute_keypoints',\n        action='store_true',\n        help='whether to randomly permute keypoints')\n    parser.add_argument(\n        '--gpu_collect',\n        action='store_true',\n        help='whether to use gpu to collect results')\n    parser.add_argument('--tmpdir', help='tmp dir for writing some results')\n    parser.add_argument(\n        '--cfg-options',\n        nargs='+',\n        action=DictAction,\n        default={},\n        help='override some settings in the used config, the key-value pair '\n        'in xxx=yyy format will be merged into config file. 
 For example, '\n        \"'--cfg-options model.backbone.depth=18 model.backbone.with_cp=True'\")\n    parser.add_argument(\n        '--launcher',\n        choices=['none', 'pytorch', 'slurm', 'mpi'],\n        default='none',\n        help='job launcher')\n    parser.add_argument('--local_rank', type=int, default=0)\n    args = parser.parse_args()\n    if 'LOCAL_RANK' not in os.environ:\n        os.environ['LOCAL_RANK'] = str(args.local_rank)\n    return args\n\n\ndef merge_configs(cfg1, cfg2):\n    # Merge cfg2 into cfg1.\n    # Overwrite cfg1 if a key is repeated; skip values that are None or\n    # otherwise falsy.\n    cfg1 = {} if cfg1 is None else cfg1.copy()\n    cfg2 = {} if cfg2 is None else cfg2\n    for k, v in cfg2.items():\n        if v:\n            cfg1[k] = v\n    return cfg1\n\n\ndef main():\n    random.seed(0)\n    np.random.seed(0)\n    torch.manual_seed(0)\n\n    args = parse_args()\n\n    cfg = Config.fromfile(args.config)\n\n    if args.cfg_options is not None:\n        cfg.merge_from_dict(args.cfg_options)\n    # set cudnn_benchmark\n    if cfg.get('cudnn_benchmark', False):\n        torch.backends.cudnn.benchmark = True\n    # cfg.model.pretrained = None\n    cfg.data.test.test_mode = True\n\n    args.work_dir = osp.join('./work_dirs',\n                             osp.splitext(osp.basename(args.config))[0])\n    mmcv.mkdir_or_exist(osp.abspath(args.work_dir))\n\n    # init distributed env first, since logger depends on the dist info.\n    if args.launcher == 'none':\n        distributed = False\n    else:\n        distributed = True\n        init_dist(args.launcher, **cfg.dist_params)\n\n    # build the dataloader\n    dataset = build_dataset(cfg.data.test, dict(test_mode=True))\n    dataloader_setting = dict(\n        samples_per_gpu=1,\n        workers_per_gpu=cfg.data.get('workers_per_gpu', 12),\n        dist=distributed,\n        shuffle=False,\n        drop_last=False)\n    dataloader_setting = dict(dataloader_setting,\n                              **cfg.data.get('test_dataloader', {}))\n    data_loader = build_dataloader(dataset, **dataloader_setting)\n\n    # build the model and load checkpoint\n    model = build_posenet(cfg.model)\n    fp16_cfg = cfg.get('fp16', None)\n    if fp16_cfg is not None:\n        wrap_fp16_model(model)\n    load_checkpoint(model, args.checkpoint, map_location='cpu')\n\n    if args.fuse_conv_bn:\n        model = fuse_conv_bn(model)\n\n    if not distributed:\n        model = MMDataParallel(model, device_ids=[0])\n        outputs = single_gpu_test(model, data_loader)\n    else:\n        model = MMDistributedDataParallel(\n            model.cuda(),\n            device_ids=[torch.cuda.current_device()],\n            broadcast_buffers=False)\n        outputs = multi_gpu_test(model, data_loader, args.tmpdir, args.gpu_collect)\n\n    rank, _ = get_dist_info()\n    eval_config = cfg.get('evaluation', {})\n    eval_config = merge_configs(eval_config, dict(metric=args.eval))\n\n    if rank == 0:\n        if args.out:\n            print(f'\\nwriting results to {args.out}')\n            mmcv.dump(outputs, args.out)\n\n        results = dataset.evaluate(outputs, **eval_config)\n        print('\\n')\n        for k, v in sorted(results.items()):\n            print(f'{k}: {v}')\n\n        # save testing log\n        test_log = \"./work_dirs/testing_log.txt\"\n        with open(test_log, 'a') as f:\n            f.write(\"**  config_file: \" + args.config + \"\\t checkpoint: \" + args.checkpoint + \"\\t \\n\")\n            for k, v in sorted(results.items()):\n
                f.write(f'\\t {k}: {v}\\n')\n            f.write(\"********************************************************************\\n\")\n\n\nif __name__ == '__main__':\n    main()\n"
  },
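A quick demonstration of the merge semantics of `merge_configs` in `test.py`: values from the second dict win, but falsy values are skipped, so an `--eval` flag that was never passed does not clobber the metric configured in the config file. The toy config values here are illustrative:

```python
# merge_configs copied from test.py, exercised on a toy evaluation config.
def merge_configs(cfg1, cfg2):
    cfg1 = {} if cfg1 is None else cfg1.copy()
    cfg2 = {} if cfg2 is None else cfg2
    for k, v in cfg2.items():
        if v:
            cfg1[k] = v
    return cfg1


eval_config = {'metric': 'mAP', 'interval': 1}
print(merge_configs(eval_config, dict(metric=None)))     # {'metric': 'mAP', 'interval': 1}
print(merge_configs(eval_config, dict(metric=['PCK'])))  # {'metric': ['PCK'], 'interval': 1}
```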
  {
    "path": "tools/dist_test.sh",
    "content": "#!/usr/bin/env bash\n# Copyright (c) OpenMMLab. All rights reserved.\n\nCONFIG=$1\nCHECKPOINT=$2\nGPUS=$3\nPORT=${PORT:-29000}\n\nPYTHONPATH=\"$(dirname $0)/..\":$PYTHONPATH \\\npython -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \\\n    $(dirname \"$0\")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4}\n"
  },
  {
    "path": "tools/dist_train.sh",
    "content": "#!/usr/bin/env bash\n# Copyright (c) OpenMMLab. All rights reserved.\n\nCONFIG=$1\nGPUS=$2\nOUTPUT_DIR=$3\nPORT=${PORT:-29000}\n\nPYTHONPATH=\"$(dirname $0)/..\":$PYTHONPATH \\\npython -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \\\n    $(dirname \"$0\")/train.py $CONFIG --work-dir $OUTPUT_DIR --launcher pytorch ${@:3}\n"
  },
  {
    "path": "tools/fix_carfuxion.py",
    "content": "import json\nimport os\nimport shutil\nimport sys\nimport numpy as np\nfrom xtcocotools.coco import COCO\n\n\ndef search_match(bbox, num_keypoints, segmentation):\n    found = []\n    checked = 0\n    for json_file, coco in COCO_DICT.items():\n        cat_ids = coco.getCatIds()\n        for cat_id in cat_ids:\n            img_ids = coco.getImgIds(catIds=cat_id)\n            for img_id in img_ids:\n                annotations = coco.loadAnns(coco.getAnnIds(imgIds=img_id, catIds=cat_id))\n                for ann in annotations:\n                    checked += 1\n                    if (ann['num_keypoints'] == num_keypoints and ann['bbox'] == bbox and ann[\n                        'segmentation'] == segmentation):\n                        src_file = coco.loadImgs(img_id)[0][\"file_name\"]\n                        split = \"test\" if \"test\" in json_file else \"train\"\n                        found.append((src_file, ann, split))\n                        # return src_file, ann, split\n    if len(found) == 0:\n        raise Exception(\"No match found out of {} images\".format(checked))\n    elif len(found) > 1:\n        raise Exception(\"More than one match! \".format(found))\n    return found[0]\n\nif __name__ == \"__main__\":\n\n    carfusion_dir_path = sys.argv[1]\n    mp100_dataset_path = sys.argv[2]\n    os.makedirs('output', exist_ok=True)\n    for cat in ['car', 'bus', 'suv']:\n        os.makedirs(os.path.join('output', cat), exist_ok=True)\n\n\n    COCO_DICT = {}\n    ann_files = os.path.join(carfusion_dir_path, 'annotations')\n    for json_file in os.listdir(ann_files):\n        COCO_DICT[json_file] = COCO(os.path.join(carfusion_dir_path, 'annotations', json_file))\n\n    count = 0\n    print_log = []\n    for json_file in os.listdir(mp100_dataset_path):\n        print(\"Processing {}\".format(json_file))\n        cats = {}\n        coco = COCO(os.path.join(mp100_dataset_path, json_file))\n        cat_ids = coco.getCatIds()\n        for cat_id in cat_ids:\n            category_info = coco.loadCats(cat_id)\n            cat_name = category_info[0]['name']\n            if cat_name in ['car', 'bus', 'suv']:\n                cats[cat_name] = cat_id\n\n\n        for cat_name, cat_id in cats.items():\n            img_ids = coco.getImgIds(catIds=cat_id)\n            count += len(img_ids)\n            print_log.append(f'{json_file} : {cat_name}: {len(img_ids)}')\n            for img_id in img_ids:\n                img = coco.loadImgs(img_id)[0]\n                dst_file_name = img['file_name']\n                annotation = coco.loadAnns(coco.getAnnIds(imgIds=img_id, catIds=cat_id, iscrowd=None))\n                bbox = annotation[0]['bbox']\n                keypoints = annotation[0]['keypoints']\n                segmentation = annotation[0]['segmentation']\n                num_keypoints = annotation[0]['num_keypoints']\n\n                # Search for a match:\n                src_img, src_ann, split = search_match(bbox, num_keypoints, segmentation)\n                shutil.copyfile(\n                    os.path.join(carfusion_dir_path, split, src_img),\n                    os.path.join('output', dst_file_name))"
  },
  {
    "path": "tools/slurm_test.sh",
    "content": "#!/usr/bin/env bash\n\nset -x\n\nPARTITION=$1\nJOB_NAME=$2\nCONFIG=$3\nCHECKPOINT=$4\nGPUS=${GPUS:-8}\nGPUS_PER_NODE=${GPUS_PER_NODE:-8}\nCPUS_PER_TASK=${CPUS_PER_TASK:-5}\nPY_ARGS=${@:5}\nSRUN_ARGS=${SRUN_ARGS:-\"\"}\n\nPYTHONPATH=\"$(dirname $0)/..\":$PYTHONPATH \\\nsrun -p ${PARTITION} \\\n    --job-name=${JOB_NAME} \\\n    --gres=gpu:${GPUS_PER_NODE} \\\n    --ntasks=${GPUS} \\\n    --ntasks-per-node=${GPUS_PER_NODE} \\\n    --cpus-per-task=${CPUS_PER_TASK} \\\n    --kill-on-bad-exit=1 \\\n    ${SRUN_ARGS} \\\n    python -u test.py ${CONFIG} ${CHECKPOINT} --launcher=\"slurm\" ${PY_ARGS}\n"
  },
  {
    "path": "tools/slurm_train.sh",
    "content": "#!/usr/bin/env bash\n\nset -x\n\nPARTITION=$1\nJOB_NAME=$2\nCONFIG=$3\nWORK_DIR=$4\nGPUS=${GPUS:-8}\nGPUS_PER_NODE=${GPUS_PER_NODE:-8}\nCPUS_PER_TASK=${CPUS_PER_TASK:-5}\nPY_ARGS=${@:5}\nSRUN_ARGS=${SRUN_ARGS:-\"\"}\n\nPYTHONPATH=\"$(dirname $0)/..\":$PYTHONPATH \\\nsrun -p ${PARTITION} \\\n    --job-name=${JOB_NAME} \\\n    --gres=gpu:${GPUS_PER_NODE} \\\n    --ntasks=${GPUS} \\\n    --ntasks-per-node=${GPUS_PER_NODE} \\\n    --cpus-per-task=${CPUS_PER_TASK} \\\n    --kill-on-bad-exit=1 \\\n    ${SRUN_ARGS} \\\n    python -u train.py ${CONFIG} --work-dir=${WORK_DIR} --launcher=\"slurm\" ${PY_ARGS}\n"
  },
  {
    "path": "tools/visualization.py",
    "content": "import os\nimport random\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport uuid\n\ncolors = [\n    [255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0],\n    [85, 255, 0], [0, 255, 0], [0, 255, 85], [0, 255, 170], [0, 255, 255],\n    [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], [170, 0, 255],\n    [255, 0, 255], [255, 0, 170], [255, 0, 85], [255, 0, 0]]\n\n\ndef plot_results(support_img, query_img, support_kp, support_w, query_kp, query_w, skeleton,\n                 initial_proposals, prediction, radius=6, out_dir='./heatmaps'):\n    img_names = [img.split(\"_\")[0] for img in os.listdir(out_dir) if str_is_int(img.split(\"_\")[0])]\n    if len(img_names) > 0:\n        name_idx = max([int(img_name) for img_name in img_names]) + 1\n    else:\n        name_idx = 0\n\n    h, w, c = support_img.shape\n    prediction = prediction[-1].cpu().numpy() * h\n    support_img = (support_img - np.min(support_img)) / (np.max(support_img) - np.min(support_img))\n    query_img = (query_img - np.min(query_img)) / (np.max(query_img) - np.min(query_img))\n\n    for id, (img, w, keypoint) in enumerate(zip([support_img, query_img],\n                                                [support_w, query_w],\n                                                [support_kp, prediction])):\n        f, axes = plt.subplots()\n        plt.imshow(img)\n        for k in range(keypoint.shape[0]):\n            if w[k] > 0:\n                kp = keypoint[k, :2]\n                c = (1, 0, 0, 0.75) if w[k] == 1 else (0, 0, 1, 0.6)\n                patch = plt.Circle(kp, radius, color=c)\n                axes.add_patch(patch)\n                axes.text(kp[0], kp[1], k)\n                plt.draw()\n        for l, limb in enumerate(skeleton):\n            kp = keypoint[:, :2]\n            if l > len(colors) - 1:\n                c = [x / 255 for x in random.sample(range(0, 255), 3)]\n            else:\n                c = [x / 255 for x in colors[l]]\n            if w[limb[0]] > 0 and w[limb[1]] > 0:\n                patch = plt.Line2D([kp[limb[0], 0], kp[limb[1], 0]],\n                                   [kp[limb[0], 1], kp[limb[1], 1]],\n                                   linewidth=6, color=c, alpha=0.6)\n                axes.add_artist(patch)\n        plt.axis('off')  # command for hiding the axis.\n        name = 'support' if id == 0 else 'query'\n        plt.savefig(f'./{out_dir}/{str(name_idx)}_{str(name)}.png', bbox_inches='tight', pad_inches=0)\n        if id == 1:\n            plt.show()\n        plt.clf()\n        plt.close('all')\n\n\ndef str_is_int(s):\n    try:\n        int(s)\n        return True\n    except ValueError:\n        return False\n"
  },
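A hypothetical smoke test for `plot_results` above, using dummy data. It assumes the repository root is on `PYTHONPATH`, that predictions are stacked per decoder layer and normalized to [0, 1] (so `prediction[-1] * h` recovers pixel coordinates), and that the output directory exists before the call; all values here are made up for illustration:

```python
import os

import numpy as np
import torch

from tools.visualization import plot_results  # assumes repo root on PYTHONPATH

os.makedirs('./heatmaps', exist_ok=True)  # plot_results lists this dir first

h = w = 256  # image size
k = 5        # number of keypoints
support_img = np.random.rand(h, w, 3)
query_img = np.random.rand(h, w, 3)
support_kp = np.random.rand(k, 2) * h        # support keypoints, in pixels
support_w = np.ones(k)                       # all keypoints visible
query_w = np.ones(k)
skeleton = [[0, 1], [1, 2], [2, 3], [3, 4]]  # a simple chain of limbs
prediction = torch.rand(3, k, 2)             # 3 decoder layers, normalized

# query_kp and initial_proposals are accepted but unused by plot_results,
# so None is fine for a smoke test.
plot_results(support_img, query_img, support_kp, support_w,
             None, query_w, skeleton, None, prediction)
```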
  {
    "path": "train.py",
    "content": "import argparse\nimport copy\nimport os\nimport os.path as osp\nimport time\n\nimport mmcv\nimport torch\nfrom mmcv import Config, DictAction\nfrom mmcv.runner import get_dist_info, init_dist, set_random_seed\nfrom mmcv.utils import get_git_hash\n\nfrom models import *  # noqa\nfrom models.apis import train_model\nfrom models.datasets import build_dataset\n\nfrom mmpose import __version__\nfrom mmpose.models import build_posenet\nfrom mmpose.utils import collect_env, get_root_logger\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description='Train a pose model')\n    parser.add_argument('--config', default=None, help='train config file path')\n    parser.add_argument('--work-dir', default=None, help='the dir to save logs and models')\n    parser.add_argument(\n        '--resume-from', help='the checkpoint file to resume from')\n    parser.add_argument(\n        '--auto-resume', type=bool, default=True, help='automatically detect the latest checkpoint in word dir and resume from it.')\n    parser.add_argument(\n        '--no-validate',\n        action='store_true',\n        help='whether not to evaluate the checkpoint during training')\n    group_gpus = parser.add_mutually_exclusive_group()\n    group_gpus.add_argument(\n        '--gpus',\n        type=int,\n        help='number of gpus to use '\n        '(only applicable to non-distributed training)')\n    group_gpus.add_argument(\n        '--gpu-ids',\n        type=int,\n        nargs='+',\n        help='ids of gpus to use '\n        '(only applicable to non-distributed training)')\n    parser.add_argument('--seed', type=int, default=None, help='random seed')\n    parser.add_argument(\n        '--deterministic',\n        action='store_true',\n        help='whether to set deterministic options for CUDNN backend.') \n    parser.add_argument(\n        '--cfg-options',\n        nargs='+',\n        action=DictAction,\n        default={},\n        help='override some settings in the used config, the key-value pair '\n        'in xxx=yyy format will be merged into config file. 
 For example, '\n        \"'--cfg-options model.backbone.depth=18 model.backbone.with_cp=True'\")\n    parser.add_argument(\n        '--launcher',\n        choices=['none', 'pytorch', 'slurm', 'mpi'],\n        default='none',\n        help='job launcher')\n    parser.add_argument('--local_rank', type=int, default=0)\n    parser.add_argument(\n        '--autoscale-lr',\n        action='store_true',\n        help='automatically scale lr with the number of gpus')\n    parser.add_argument(\n        '--show',\n        action='store_true',\n        help='whether to display the prediction results in a window.')\n    args = parser.parse_args()\n    if 'LOCAL_RANK' not in os.environ:\n        os.environ['LOCAL_RANK'] = str(args.local_rank)\n\n    return args\n\n\ndef main():\n    args = parse_args()\n\n    cfg = Config.fromfile(args.config)\n\n    if args.cfg_options is not None:\n        cfg.merge_from_dict(args.cfg_options)\n\n    # set cudnn_benchmark\n    if cfg.get('cudnn_benchmark', False):\n        torch.backends.cudnn.benchmark = True\n\n    # work_dir is determined in this priority: CLI\n    # > segment in file > filename\n    if args.work_dir is not None:\n        # update configs according to CLI args if args.work_dir is not None\n        cfg.work_dir = args.work_dir\n    elif cfg.get('work_dir', None) is None:\n        # use config filename as default work_dir if cfg.work_dir is None\n        cfg.work_dir = osp.join('./work_dirs',\n                                osp.splitext(osp.basename(args.config))[0])\n    # auto resume from the latest checkpoint in the resolved work_dir\n    if args.auto_resume:\n        checkpoint = os.path.join(cfg.work_dir, 'latest.pth')\n        if os.path.exists(checkpoint):\n            cfg.resume_from = checkpoint\n\n    if args.resume_from is not None:\n        cfg.resume_from = args.resume_from\n    if args.gpu_ids is not None:\n        cfg.gpu_ids = args.gpu_ids\n    else:\n        cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)\n\n    if args.autoscale_lr:\n        # apply the linear scaling rule (https://arxiv.org/abs/1706.02677)\n        cfg.optimizer['lr'] = cfg.optimizer['lr'] * len(cfg.gpu_ids) / 8\n\n    # init distributed env first, since logger depends on the dist info.\n    if args.launcher == 'none':\n        distributed = False\n    else:\n        distributed = True\n        init_dist(args.launcher, **cfg.dist_params)\n        # re-set gpu_ids with distributed training mode\n        _, world_size = get_dist_info()\n        cfg.gpu_ids = range(world_size)\n\n    # create work_dir\n    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))\n    # init the logger before other steps\n    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())\n    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')\n    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)\n\n    # init the meta dict to record some important information such as\n    # environment info and seed, which will be logged\n    meta = dict()\n    # log env info\n    env_info_dict = collect_env()\n    env_info = '\\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])\n    dash_line = '-' * 60 + '\\n'\n    logger.info('Environment info:\\n' + dash_line + env_info + '\\n' +\n                dash_line)\n    meta['env_info'] = env_info\n\n    # log some basic info\n    logger.info(f'Distributed training: {distributed}')\n    logger.info(f'Config:\\n{cfg.pretty_text}')\n\n    # set random seeds\n    # NOTE: the seed and determinism are forced here, overriding the\n    # --seed and --deterministic CLI flags\n    args.seed = 1\n    args.deterministic = True\n    if args.seed is not None:\n        logger.info(f'Set 
random seed to {args.seed}, '\n                    f'deterministic: {args.deterministic}')\n        set_random_seed(args.seed, deterministic=args.deterministic)\n    cfg.seed = args.seed\n    meta['seed'] = args.seed\n\n    model = build_posenet(cfg.model)\n    train_datasets = [build_dataset(cfg.data.train)]\n\n    # if len(cfg.workflow) == 2:\n    #     val_dataset = copy.deepcopy(cfg.data.val)\n    #     val_dataset.pipeline = cfg.data.train.pipeline\n    #     datasets.append(build_dataset(val_dataset))\n\n    val_dataset = copy.deepcopy(cfg.data.val)\n    val_dataset = build_dataset(val_dataset, dict(test_mode=True))\n\n    if cfg.checkpoint_config is not None:\n        # save mmpose version, config file content\n        # checkpoints as meta data\n        cfg.checkpoint_config.meta = dict(\n            mmpose_version=__version__ + get_git_hash(digits=7),\n            config=cfg.pretty_text,\n        )\n    train_model(\n        model,\n        train_datasets,\n        val_dataset,\n        cfg,\n        distributed=distributed,\n        validate=(not args.no_validate),\n        timestamp=timestamp,\n        meta=meta)\n\n\nif __name__ == '__main__':\n    main()\n"
  }
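One numeric note on `train.py`: with `--autoscale-lr`, the learning rate in the config is treated as an 8-GPU base value and scaled linearly with the number of GPUs actually used (the linear scaling rule cited in the code). A worked example with a hypothetical base learning rate:

```python
# Worked example of train.py's --autoscale-lr rule:
# cfg.optimizer['lr'] = cfg.optimizer['lr'] * len(cfg.gpu_ids) / 8
base_lr = 1e-4  # hypothetical config value, tuned for 8 GPUs
for num_gpus in (1, 4, 8, 16):
    print(f'{num_gpus:>2} GPUs -> lr {base_lr * num_gpus / 8:.2e}')
#  1 GPUs -> lr 1.25e-05
#  4 GPUs -> lr 5.00e-05
#  8 GPUs -> lr 1.00e-04
# 16 GPUs -> lr 2.00e-04
```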
]